From ccddcbaccdd0c2bb50cbb2a4fb5de0fc9ade70d8 Mon Sep 17 00:00:00 2001 From: Daniel Trebbien Date: Sun, 23 Apr 2017 16:31:17 -0400 Subject: [PATCH 001/697] Delete unnecessary forward declarations These are made in tensorflow/stream_executor/stream_executor_internal.h Additionally, RngSupport is within perftools::gputools::rng rather than perftools::gputools::internal. --- tensorflow/stream_executor/cuda/cuda_gpu_executor.cc | 8 -------- tensorflow/stream_executor/cuda/cuda_gpu_executor.h | 11 ----------- 2 files changed, 19 deletions(-) diff --git a/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc b/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc index 1bb90afd63..c1e72bb565 100644 --- a/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc +++ b/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc @@ -67,14 +67,6 @@ limitations under the License. extern bool FLAGS_check_gpu_leaks; bool FLAGS_prefer_cubin_to_ptx = true; -namespace perftools { -namespace gputools { -namespace rng { -class RngSupport; -} // namespace rng -} // namespace gputools -} // namespace perftools - namespace perftools { namespace gputools { namespace cuda { diff --git a/tensorflow/stream_executor/cuda/cuda_gpu_executor.h b/tensorflow/stream_executor/cuda/cuda_gpu_executor.h index 9d386b5ed9..6c5b9dca90 100644 --- a/tensorflow/stream_executor/cuda/cuda_gpu_executor.h +++ b/tensorflow/stream_executor/cuda/cuda_gpu_executor.h @@ -35,17 +35,6 @@ limitations under the License. #include "tensorflow/stream_executor/platform/thread_annotations.h" #include "tensorflow/stream_executor/stream_executor_internal.h" -namespace perftools { -namespace gputools { -namespace blas { -class BlasSupport; -} -namespace internal { -class RngSupport; -} // namespace internal -} // namespace gputools -} // namespace perftools - namespace perftools { namespace gputools { namespace cuda { -- GitLab From f3ec1cc35f883f06ec917369d3229c61a1bd9b35 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Wed, 26 Apr 2017 00:29:19 -0400 Subject: [PATCH 002/697] add support for flat both inner and outer dims --- tensorflow/core/framework/tensor.cc | 44 +++---- tensorflow/core/framework/tensor.h | 38 +++++- tensorflow/core/framework/tensor_test.cc | 155 +++++++++++++++++++++++ 3 files changed, 203 insertions(+), 34 deletions(-) diff --git a/tensorflow/core/framework/tensor.cc b/tensorflow/core/framework/tensor.cc index ecb9810d83..b11934a8a2 100644 --- a/tensorflow/core/framework/tensor.cc +++ b/tensorflow/core/framework/tensor.cc @@ -902,43 +902,31 @@ void Tensor::FillDescription(TensorDescription* description) const { } gtl::InlinedVector Tensor::ComputeFlatInnerDims( - int64 num_out_dims) const { - if (num_out_dims == dims()) { - return shape_.dim_sizes(); + gtl::InlinedVector orig, int64 num_out_dims) { + if (num_out_dims == orig.size()) { + return orig; } gtl::InlinedVector out_dims(num_out_dims, 0); - const int64 num_elements = NumElements(); - int64 prod_out_dims = 1; - for (int64 out_dim = num_out_dims - 1; out_dim > 0; --out_dim) { - const int64 in_dim = out_dim + (dims() - num_out_dims); - out_dims[out_dim] = (in_dim >= dims() || in_dim < 0) ? 1 : dim_size(in_dim); - prod_out_dims *= out_dims[out_dim]; - } - if (prod_out_dims != 0) { - out_dims[0] = num_elements / prod_out_dims; - } else { - out_dims[0] = 0; + int64 offset = orig.size() - num_out_dims; + for (int64 out_dim = num_out_dims - 1; out_dim >= 0; --out_dim) { + const int64 in_dim = out_dim + offset; + out_dims[out_dim] = in_dim < 0 ? 1 : orig[in_dim]; } + for (int64 in_dim = 0; in_dim < offset; ++in_dim) + out_dims[0] *= orig[in_dim]; return out_dims; } gtl::InlinedVector Tensor::ComputeFlatOuterDims( - int64 num_out_dims) const { - if (num_out_dims == dims()) { - return shape_.dim_sizes(); + gtl::InlinedVector orig, int64 num_out_dims) { + if (num_out_dims == orig.size()) { + return orig; } gtl::InlinedVector out_dims(num_out_dims, 0); - const int64 num_elements = NumElements(); - int64 prod_out_dims = 1; - for (int64 out_dim = 0; out_dim < num_out_dims - 1; ++out_dim) { - out_dims[out_dim] = out_dim >= dims() ? 1 : dim_size(out_dim); - prod_out_dims *= out_dims[out_dim]; - } - if (prod_out_dims != 0) { - out_dims[num_out_dims - 1] = num_elements / prod_out_dims; - } else { - out_dims[num_out_dims - 1] = 0; - } + for (int64 out_dim = 0; out_dim <= num_out_dims - 1; ++out_dim) + out_dims[out_dim] = out_dim >= orig.size() ? 1 : orig[out_dim]; + for (int64 in_dim = num_out_dims; in_dim < orig.size(); ++in_dim) + out_dims[num_out_dims - 1] *= orig[in_dim]; return out_dims; } diff --git a/tensorflow/core/framework/tensor.h b/tensorflow/core/framework/tensor.h index 103da4c1b3..08508fab3b 100644 --- a/tensorflow/core/framework/tensor.h +++ b/tensorflow/core/framework/tensor.h @@ -304,6 +304,15 @@ class Tensor { template typename TTypes::Tensor flat_outer_dims(); + /// Returns the data as an Eigen::Tensor with END-START dimensions, collapsing + /// Tensor dimensions of the first START into the first dimension of the + /// result and the Tensor dimensions of the last dims()-END into the last dimension + /// of the result. If START < 0 then the the |START| leading dimensions of size 1 will be + /// added. If END > dims() then END - dims() trailing dimensions of size 1 will be + /// added. + template + typename TTypes::Tensor flat_inner_outer_dims(); + template typename TTypes::Tensor shaped(gtl::ArraySlice new_sizes); @@ -386,6 +395,9 @@ class Tensor { template typename TTypes::ConstTensor flat_outer_dims() const; + template + typename TTypes::Tensor flat_inner_outer_dims() const; + /// Render the first `max_entries` values in `*this` into a string. string SummarizeValue(int64 max_entries) const; @@ -431,8 +443,10 @@ class Tensor { // TODO(rmlarsen): These shouldn't hardcode '4' so that it lines up with // TensorShape's InlineVector. - gtl::InlinedVector ComputeFlatInnerDims(int64 num_out_dims) const; - gtl::InlinedVector ComputeFlatOuterDims(int64 num_out_dims) const; + static gtl::InlinedVector ComputeFlatInnerDims( + gtl::InlinedVector orig, int64 num_out_dims); + static gtl::InlinedVector ComputeFlatOuterDims( + gtl::InlinedVector orig, int64 num_out_dims); TensorShape shape_; TensorBuffer* buf_; @@ -638,22 +652,34 @@ typename TTypes::ConstScalar Tensor::scalar() const { template typename TTypes::Tensor Tensor::flat_inner_dims() { - return shaped(ComputeFlatInnerDims(NDIMS)); + return shaped(ComputeFlatInnerDims(shape_.dim_sizes(), NDIMS)); } template typename TTypes::Tensor Tensor::flat_outer_dims() { - return shaped(ComputeFlatOuterDims(NDIMS)); + return shaped(ComputeFlatOuterDims(shape_.dim_sizes(), NDIMS)); +} + +template +typename TTypes::Tensor Tensor::flat_inner_outer_dims() { + gtl::InlinedVector o = ComputeFlatOuterDims(shape_.dim_sizes(), END); + return shaped(ComputeFlatInnerDims(o, END-BEGIN)); } template typename TTypes::ConstTensor Tensor::flat_inner_dims() const { - return shaped(ComputeFlatInnerDims(NDIMS)); + return shaped(ComputeFlatInnerDims(shape_.dim_sizes(), NDIMS)); } template typename TTypes::ConstTensor Tensor::flat_outer_dims() const { - return shaped(ComputeFlatOuterDims(NDIMS)); + return shaped(ComputeFlatOuterDims(shape_.dim_sizes(), NDIMS)); +} + +template +typename TTypes::Tensor Tensor::flat_inner_outer_dims() const { + gtl::InlinedVector o = ComputeFlatOuterDims(shape_.dim_sizes(), END); + return shaped(ComputeFlatInnerDims(o, END-BEGIN)); } inline Tensor::Tensor(const Tensor& other) diff --git a/tensorflow/core/framework/tensor_test.cc b/tensorflow/core/framework/tensor_test.cc index c907bbb69f..8f6d397607 100644 --- a/tensorflow/core/framework/tensor_test.cc +++ b/tensorflow/core/framework/tensor_test.cc @@ -218,12 +218,14 @@ TEST(Tensor_Float, Reshape) { tensor(1, 2, 3, 4) = 0.02f; } { + LOG(INFO) << "shaped"; auto shaped = t.shaped({120}); EXPECT_EQ(120, shaped.dimension(0)); EXPECT_EQ(shaped(0), 0.01f); EXPECT_EQ(shaped(119), 0.02f); } { + LOG(INFO) << "shaped"; auto shaped = t.shaped({6, 20}); EXPECT_EQ(6, shaped.dimension(0)); EXPECT_EQ(20, shaped.dimension(1)); @@ -231,6 +233,7 @@ TEST(Tensor_Float, Reshape) { EXPECT_EQ(shaped(5, 19), 0.02f); } { + LOG(INFO) << "shaped"; auto shaped = t.shaped({6, 4, 5}); EXPECT_EQ(6, shaped.dimension(0)); EXPECT_EQ(4, shaped.dimension(1)); @@ -239,6 +242,7 @@ TEST(Tensor_Float, Reshape) { EXPECT_EQ(shaped(5, 3, 4), 0.02f); } { + LOG(INFO) << "shaped"; auto shaped = t.shaped({2, 3, 4, 5}); EXPECT_EQ(2, shaped.dimension(0)); EXPECT_EQ(3, shaped.dimension(1)); @@ -249,6 +253,7 @@ TEST(Tensor_Float, Reshape) { EXPECT_EQ(shaped(1, 2, 3, 4), 0.02f); } { + LOG(INFO) << "flat"; auto flat = t.flat(); EXPECT_EQ(flat(0), 0.01f); EXPECT_EQ(120, flat.dimension(0)); @@ -256,6 +261,7 @@ TEST(Tensor_Float, Reshape) { EXPECT_EQ(flat(119), 0.02f); } { + LOG(INFO) << "flat_inner_dims"; auto flat_inner_dims = t.flat_inner_dims(); EXPECT_EQ(24, flat_inner_dims.dimension(0)); EXPECT_EQ(5, flat_inner_dims.dimension(1)); @@ -263,6 +269,7 @@ TEST(Tensor_Float, Reshape) { EXPECT_EQ(flat_inner_dims(23, 4), 0.02f); } { + LOG(INFO) << "flat_outer_dims"; auto flat_outer_dims = t.flat_outer_dims(); EXPECT_EQ(2, flat_outer_dims.dimension(0)); EXPECT_EQ(60, flat_outer_dims.dimension(1)); @@ -270,6 +277,7 @@ TEST(Tensor_Float, Reshape) { EXPECT_EQ(flat_outer_dims(1, 59), 0.02f); } { + LOG(INFO) << "flat_inner_dims"; auto flat_inner_dims = t.flat_inner_dims(); EXPECT_EQ(6, flat_inner_dims.dimension(0)); EXPECT_EQ(4, flat_inner_dims.dimension(1)); @@ -278,6 +286,7 @@ TEST(Tensor_Float, Reshape) { EXPECT_EQ(flat_inner_dims(5, 3, 4), 0.02f); } { + LOG(INFO) << "flat_outer_dims"; auto flat_outer_dims = t.flat_outer_dims(); EXPECT_EQ(2, flat_outer_dims.dimension(0)); EXPECT_EQ(3, flat_outer_dims.dimension(1)); @@ -286,6 +295,7 @@ TEST(Tensor_Float, Reshape) { EXPECT_EQ(flat_outer_dims(1, 2, 19), 0.02f); } { + LOG(INFO) << "flat_inner_dims"; auto flat_inner_dims = t.flat_inner_dims(); EXPECT_EQ(1, flat_inner_dims.dimension(0)); EXPECT_EQ(2, flat_inner_dims.dimension(1)); @@ -296,6 +306,7 @@ TEST(Tensor_Float, Reshape) { EXPECT_EQ(flat_inner_dims(0, 1, 2, 3, 4), 0.02f); } { + LOG(INFO) << "flat_outer_dims"; auto flat_outer_dims = t.flat_outer_dims(); EXPECT_EQ(2, flat_outer_dims.dimension(0)); EXPECT_EQ(3, flat_outer_dims.dimension(1)); @@ -305,20 +316,119 @@ TEST(Tensor_Float, Reshape) { EXPECT_EQ(flat_outer_dims(0, 0, 0, 0, 0), 0.01f); EXPECT_EQ(flat_outer_dims(1, 2, 3, 4, 0), 0.02f); } + { + LOG(INFO) << "flat_inner_outer_dims"; + auto flat_inner_outer_dims = t.flat_inner_outer_dims(); + EXPECT_EQ(2, flat_inner_outer_dims.dimension(0)); + EXPECT_EQ(3, flat_inner_outer_dims.dimension(1)); + EXPECT_EQ(4, flat_inner_outer_dims.dimension(2)); + EXPECT_EQ(5, flat_inner_outer_dims.dimension(3)); + EXPECT_EQ(flat_inner_outer_dims(0, 0, 0, 0), 0.01f); + EXPECT_EQ(flat_inner_outer_dims(1, 2, 3, 4), 0.02f); + } + { + LOG(INFO) << "flat_inner_outer_dims"; + auto flat_inner_outer_dims = t.flat_inner_outer_dims(); + EXPECT_EQ(1, flat_inner_outer_dims.dimension(0)); + EXPECT_EQ(1, flat_inner_outer_dims.dimension(1)); + EXPECT_EQ(2, flat_inner_outer_dims.dimension(2)); + EXPECT_EQ(3, flat_inner_outer_dims.dimension(3)); + EXPECT_EQ(4, flat_inner_outer_dims.dimension(4)); + EXPECT_EQ(5, flat_inner_outer_dims.dimension(5)); + EXPECT_EQ(flat_inner_outer_dims(0, 0, 0, 0, 0, 0), 0.01f); + EXPECT_EQ(flat_inner_outer_dims(0, 0, 1, 2, 3, 4), 0.02f); + } + { + LOG(INFO) << "flat_inner_outer_dims"; + auto flat_inner_outer_dims = t.flat_inner_outer_dims(); + EXPECT_EQ(2, flat_inner_outer_dims.dimension(0)); + EXPECT_EQ(3, flat_inner_outer_dims.dimension(1)); + EXPECT_EQ(4, flat_inner_outer_dims.dimension(2)); + EXPECT_EQ(5, flat_inner_outer_dims.dimension(3)); + EXPECT_EQ(1, flat_inner_outer_dims.dimension(4)); + EXPECT_EQ(1, flat_inner_outer_dims.dimension(5)); + EXPECT_EQ(flat_inner_outer_dims(0, 0, 0, 0, 0, 0), 0.01f); + EXPECT_EQ(flat_inner_outer_dims(1, 2, 3, 4, 0, 0), 0.02f); + } + { + LOG(INFO) << "flat_inner_outer_dims"; + auto flat_inner_outer_dims = t.flat_inner_outer_dims(); + EXPECT_EQ(1, flat_inner_outer_dims.dimension(0)); + EXPECT_EQ(1, flat_inner_outer_dims.dimension(1)); + EXPECT_EQ(2, flat_inner_outer_dims.dimension(2)); + EXPECT_EQ(3, flat_inner_outer_dims.dimension(3)); + EXPECT_EQ(4, flat_inner_outer_dims.dimension(4)); + EXPECT_EQ(5, flat_inner_outer_dims.dimension(5)); + EXPECT_EQ(1, flat_inner_outer_dims.dimension(6)); + EXPECT_EQ(1, flat_inner_outer_dims.dimension(7)); + EXPECT_EQ(flat_inner_outer_dims(0, 0, 0, 0, 0, 0, 0, 0), 0.01f); + EXPECT_EQ(flat_inner_outer_dims(0, 0, 1, 2, 3, 4, 0, 0), 0.02f); + } + { + LOG(INFO) << "flat_inner_outer_dims"; + auto flat_inner_outer_dims = t.flat_inner_outer_dims(); + EXPECT_EQ(6, flat_inner_outer_dims.dimension(0)); + EXPECT_EQ(4, flat_inner_outer_dims.dimension(1)); + EXPECT_EQ(5, flat_inner_outer_dims.dimension(2)); + EXPECT_EQ(flat_inner_outer_dims(0, 0, 0), 0.01f); + EXPECT_EQ(flat_inner_outer_dims(5, 3, 4), 0.02f); + } + { + LOG(INFO) << "flat_inner_outer_dims"; + auto flat_inner_outer_dims = t.flat_inner_outer_dims(); + EXPECT_EQ(6, flat_inner_outer_dims.dimension(0)); + EXPECT_EQ(4, flat_inner_outer_dims.dimension(1)); + EXPECT_EQ(5, flat_inner_outer_dims.dimension(2)); + EXPECT_EQ(1, flat_inner_outer_dims.dimension(3)); + EXPECT_EQ(1, flat_inner_outer_dims.dimension(4)); + EXPECT_EQ(flat_inner_outer_dims(0, 0, 0, 0, 0), 0.01f); + EXPECT_EQ(flat_inner_outer_dims(5, 3, 4, 0, 0), 0.02f); + } + { + LOG(INFO) << "flat_inner_outer_dims"; + auto flat_inner_outer_dims = t.flat_inner_outer_dims(); + EXPECT_EQ(2, flat_inner_outer_dims.dimension(0)); + EXPECT_EQ(3, flat_inner_outer_dims.dimension(1)); + EXPECT_EQ(20, flat_inner_outer_dims.dimension(2)); + EXPECT_EQ(flat_inner_outer_dims(0, 0, 0), 0.01f); + EXPECT_EQ(flat_inner_outer_dims(1, 2, 19), 0.02f); + } + { + LOG(INFO) << "flat_inner_outer_dims"; + auto flat_inner_outer_dims = t.flat_inner_outer_dims(); + EXPECT_EQ(1, flat_inner_outer_dims.dimension(0)); + EXPECT_EQ(1, flat_inner_outer_dims.dimension(1)); + EXPECT_EQ(2, flat_inner_outer_dims.dimension(2)); + EXPECT_EQ(3, flat_inner_outer_dims.dimension(3)); + EXPECT_EQ(20, flat_inner_outer_dims.dimension(4)); + EXPECT_EQ(flat_inner_outer_dims(0, 0, 0, 0, 0), 0.01f); + EXPECT_EQ(flat_inner_outer_dims(0, 0, 1, 2, 19), 0.02f); + } + { + LOG(INFO) << "flat_inner_outer_dims"; + auto flat_inner_outer_dims = t.flat_inner_outer_dims(); + EXPECT_EQ(6, flat_inner_outer_dims.dimension(0)); + EXPECT_EQ(20, flat_inner_outer_dims.dimension(1)); + EXPECT_EQ(flat_inner_outer_dims(0, 0), 0.01f); + EXPECT_EQ(flat_inner_outer_dims(5, 19), 0.02f); + } Tensor zero_t(DT_FLOAT, TensorShape({3, 0, 2, 0, 5})); { + LOG(INFO) << "flat_outer_dims"; auto flat_outer_dims = zero_t.flat_outer_dims(); EXPECT_EQ(3, flat_outer_dims.dimension(0)); EXPECT_EQ(0, flat_outer_dims.dimension(1)); } { + LOG(INFO) << "flat_outer_dims"; auto flat_outer_dims = zero_t.flat_outer_dims(); EXPECT_EQ(3, flat_outer_dims.dimension(0)); EXPECT_EQ(0, flat_outer_dims.dimension(1)); EXPECT_EQ(0, flat_outer_dims.dimension(2)); } { + LOG(INFO) << "flat_outer_dims"; auto flat_outer_dims = zero_t.flat_outer_dims(); EXPECT_EQ(3, flat_outer_dims.dimension(0)); EXPECT_EQ(0, flat_outer_dims.dimension(1)); @@ -327,17 +437,20 @@ TEST(Tensor_Float, Reshape) { EXPECT_EQ(5, flat_outer_dims.dimension(4)); } { + LOG(INFO) << "flat_inner_dims"; auto flat_inner_dims = zero_t.flat_inner_dims(); EXPECT_EQ(0, flat_inner_dims.dimension(0)); EXPECT_EQ(5, flat_inner_dims.dimension(1)); } { + LOG(INFO) << "flat_inner_dims"; auto flat_inner_dims = zero_t.flat_inner_dims(); EXPECT_EQ(0, flat_inner_dims.dimension(0)); EXPECT_EQ(0, flat_inner_dims.dimension(1)); EXPECT_EQ(5, flat_inner_dims.dimension(2)); } { + LOG(INFO) << "flat_inner_dims"; auto flat_inner_dims = zero_t.flat_inner_dims(); EXPECT_EQ(3, flat_inner_dims.dimension(0)); EXPECT_EQ(0, flat_inner_dims.dimension(1)); @@ -345,6 +458,48 @@ TEST(Tensor_Float, Reshape) { EXPECT_EQ(0, flat_inner_dims.dimension(3)); EXPECT_EQ(5, flat_inner_dims.dimension(4)); } + { + LOG(INFO) << "flat_inner_outer_dims"; + auto flat_inner_outer_dims = zero_t.flat_inner_outer_dims(); + EXPECT_EQ(3, flat_inner_outer_dims.dimension(0)); + EXPECT_EQ(0, flat_inner_outer_dims.dimension(1)); + } + { + LOG(INFO) << "flat_inner_outer_dims"; + auto flat_inner_outer_dims = zero_t.flat_inner_outer_dims(); + EXPECT_EQ(3, flat_inner_outer_dims.dimension(0)); + EXPECT_EQ(0, flat_inner_outer_dims.dimension(1)); + EXPECT_EQ(0, flat_inner_outer_dims.dimension(2)); + } + { + LOG(INFO) << "flat_inner_outer_dims"; + auto flat_inner_outer_dims = zero_t.flat_inner_outer_dims(); + EXPECT_EQ(3, flat_inner_outer_dims.dimension(0)); + EXPECT_EQ(0, flat_inner_outer_dims.dimension(1)); + EXPECT_EQ(2, flat_inner_outer_dims.dimension(2)); + EXPECT_EQ(0, flat_inner_outer_dims.dimension(3)); + EXPECT_EQ(5, flat_inner_outer_dims.dimension(4)); + } + { + LOG(INFO) << "flat_inner_outer_dims"; + auto flat_inner_outer_dims = zero_t.flat_inner_outer_dims(); + EXPECT_EQ(0, flat_inner_outer_dims.dimension(0)); + EXPECT_EQ(5, flat_inner_outer_dims.dimension(1)); + } + { + LOG(INFO) << "flat_inner_outer_dims"; + auto flat_inner_outer_dims = zero_t.flat_inner_outer_dims(); + EXPECT_EQ(0, flat_inner_outer_dims.dimension(0)); + EXPECT_EQ(0, flat_inner_outer_dims.dimension(1)); + EXPECT_EQ(5, flat_inner_outer_dims.dimension(2)); + } + { + LOG(INFO) << "flat_inner_outer_dims"; + auto flat_inner_outer_dims = zero_t.flat_inner_outer_dims(); + EXPECT_EQ(0, flat_inner_outer_dims.dimension(0)); + EXPECT_EQ(2, flat_inner_outer_dims.dimension(1)); + EXPECT_EQ(0, flat_inner_outer_dims.dimension(2)); + } } TEST(Tensor_Scalar, Basics) { -- GitLab From 5cdce84e0c14c7a922d78d449b3d20bab46c5ddf Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Wed, 26 Apr 2017 02:00:48 -0400 Subject: [PATCH 003/697] use NDIMS as template and begin as argument --- tensorflow/core/framework/tensor.h | 36 ++++++++++++------------ tensorflow/core/framework/tensor_test.cc | 30 ++++++++++---------- 2 files changed, 33 insertions(+), 33 deletions(-) diff --git a/tensorflow/core/framework/tensor.h b/tensorflow/core/framework/tensor.h index 08508fab3b..86500d5666 100644 --- a/tensorflow/core/framework/tensor.h +++ b/tensorflow/core/framework/tensor.h @@ -304,14 +304,14 @@ class Tensor { template typename TTypes::Tensor flat_outer_dims(); - /// Returns the data as an Eigen::Tensor with END-START dimensions, collapsing - /// Tensor dimensions of the first START into the first dimension of the - /// result and the Tensor dimensions of the last dims()-END into the last dimension - /// of the result. If START < 0 then the the |START| leading dimensions of size 1 will be - /// added. If END > dims() then END - dims() trailing dimensions of size 1 will be - /// added. - template - typename TTypes::Tensor flat_inner_outer_dims(); + /// Returns the data as an Eigen::Tensor with NDIMS dimensions, collapsing the + /// first 'begin' Tensor dimensions into the first dimension of the result and + /// the Tensor dimensions of the last dims() - 'begin' - NDIMS into the last + /// dimension of the result. If 'begin' < 0 then the the |'begin'| leading + /// dimensions of size 1 will be added. If 'begin' + NDIMS > dims() then + /// 'begin' + NDIMS - dims() trailing dimensions of size 1 will be added. + template + typename TTypes::Tensor flat_inner_outer_dims(int64 begin); template typename TTypes::Tensor shaped(gtl::ArraySlice new_sizes); @@ -395,8 +395,8 @@ class Tensor { template typename TTypes::ConstTensor flat_outer_dims() const; - template - typename TTypes::Tensor flat_inner_outer_dims() const; + template + typename TTypes::Tensor flat_inner_outer_dims(int64 begin) const; /// Render the first `max_entries` values in `*this` into a string. string SummarizeValue(int64 max_entries) const; @@ -660,10 +660,10 @@ typename TTypes::Tensor Tensor::flat_outer_dims() { return shaped(ComputeFlatOuterDims(shape_.dim_sizes(), NDIMS)); } -template -typename TTypes::Tensor Tensor::flat_inner_outer_dims() { - gtl::InlinedVector o = ComputeFlatOuterDims(shape_.dim_sizes(), END); - return shaped(ComputeFlatInnerDims(o, END-BEGIN)); +template +typename TTypes::Tensor Tensor::flat_inner_outer_dims(int64 begin) { + gtl::InlinedVector o = ComputeFlatOuterDims(shape_.dim_sizes(), begin+NDIMS); + return shaped(ComputeFlatInnerDims(o, NDIMS)); } template @@ -676,10 +676,10 @@ typename TTypes::ConstTensor Tensor::flat_outer_dims() const { return shaped(ComputeFlatOuterDims(shape_.dim_sizes(), NDIMS)); } -template -typename TTypes::Tensor Tensor::flat_inner_outer_dims() const { - gtl::InlinedVector o = ComputeFlatOuterDims(shape_.dim_sizes(), END); - return shaped(ComputeFlatInnerDims(o, END-BEGIN)); +template +typename TTypes::Tensor Tensor::flat_inner_outer_dims(int64 begin) const { + gtl::InlinedVector o = ComputeFlatOuterDims(shape_.dim_sizes(), begin+NDIMS); + return shaped(ComputeFlatInnerDims(o, NDIMS)); } inline Tensor::Tensor(const Tensor& other) diff --git a/tensorflow/core/framework/tensor_test.cc b/tensorflow/core/framework/tensor_test.cc index 8f6d397607..3e4179101f 100644 --- a/tensorflow/core/framework/tensor_test.cc +++ b/tensorflow/core/framework/tensor_test.cc @@ -318,7 +318,7 @@ TEST(Tensor_Float, Reshape) { } { LOG(INFO) << "flat_inner_outer_dims"; - auto flat_inner_outer_dims = t.flat_inner_outer_dims(); + auto flat_inner_outer_dims = t.flat_inner_outer_dims(0); EXPECT_EQ(2, flat_inner_outer_dims.dimension(0)); EXPECT_EQ(3, flat_inner_outer_dims.dimension(1)); EXPECT_EQ(4, flat_inner_outer_dims.dimension(2)); @@ -328,7 +328,7 @@ TEST(Tensor_Float, Reshape) { } { LOG(INFO) << "flat_inner_outer_dims"; - auto flat_inner_outer_dims = t.flat_inner_outer_dims(); + auto flat_inner_outer_dims = t.flat_inner_outer_dims(-2); EXPECT_EQ(1, flat_inner_outer_dims.dimension(0)); EXPECT_EQ(1, flat_inner_outer_dims.dimension(1)); EXPECT_EQ(2, flat_inner_outer_dims.dimension(2)); @@ -340,7 +340,7 @@ TEST(Tensor_Float, Reshape) { } { LOG(INFO) << "flat_inner_outer_dims"; - auto flat_inner_outer_dims = t.flat_inner_outer_dims(); + auto flat_inner_outer_dims = t.flat_inner_outer_dims(0); EXPECT_EQ(2, flat_inner_outer_dims.dimension(0)); EXPECT_EQ(3, flat_inner_outer_dims.dimension(1)); EXPECT_EQ(4, flat_inner_outer_dims.dimension(2)); @@ -352,7 +352,7 @@ TEST(Tensor_Float, Reshape) { } { LOG(INFO) << "flat_inner_outer_dims"; - auto flat_inner_outer_dims = t.flat_inner_outer_dims(); + auto flat_inner_outer_dims = t.flat_inner_outer_dims(-2); EXPECT_EQ(1, flat_inner_outer_dims.dimension(0)); EXPECT_EQ(1, flat_inner_outer_dims.dimension(1)); EXPECT_EQ(2, flat_inner_outer_dims.dimension(2)); @@ -366,7 +366,7 @@ TEST(Tensor_Float, Reshape) { } { LOG(INFO) << "flat_inner_outer_dims"; - auto flat_inner_outer_dims = t.flat_inner_outer_dims(); + auto flat_inner_outer_dims = t.flat_inner_outer_dims(1); EXPECT_EQ(6, flat_inner_outer_dims.dimension(0)); EXPECT_EQ(4, flat_inner_outer_dims.dimension(1)); EXPECT_EQ(5, flat_inner_outer_dims.dimension(2)); @@ -375,7 +375,7 @@ TEST(Tensor_Float, Reshape) { } { LOG(INFO) << "flat_inner_outer_dims"; - auto flat_inner_outer_dims = t.flat_inner_outer_dims(); + auto flat_inner_outer_dims = t.flat_inner_outer_dims(1); EXPECT_EQ(6, flat_inner_outer_dims.dimension(0)); EXPECT_EQ(4, flat_inner_outer_dims.dimension(1)); EXPECT_EQ(5, flat_inner_outer_dims.dimension(2)); @@ -386,7 +386,7 @@ TEST(Tensor_Float, Reshape) { } { LOG(INFO) << "flat_inner_outer_dims"; - auto flat_inner_outer_dims = t.flat_inner_outer_dims(); + auto flat_inner_outer_dims = t.flat_inner_outer_dims(0); EXPECT_EQ(2, flat_inner_outer_dims.dimension(0)); EXPECT_EQ(3, flat_inner_outer_dims.dimension(1)); EXPECT_EQ(20, flat_inner_outer_dims.dimension(2)); @@ -395,7 +395,7 @@ TEST(Tensor_Float, Reshape) { } { LOG(INFO) << "flat_inner_outer_dims"; - auto flat_inner_outer_dims = t.flat_inner_outer_dims(); + auto flat_inner_outer_dims = t.flat_inner_outer_dims(-2); EXPECT_EQ(1, flat_inner_outer_dims.dimension(0)); EXPECT_EQ(1, flat_inner_outer_dims.dimension(1)); EXPECT_EQ(2, flat_inner_outer_dims.dimension(2)); @@ -406,7 +406,7 @@ TEST(Tensor_Float, Reshape) { } { LOG(INFO) << "flat_inner_outer_dims"; - auto flat_inner_outer_dims = t.flat_inner_outer_dims(); + auto flat_inner_outer_dims = t.flat_inner_outer_dims(1); EXPECT_EQ(6, flat_inner_outer_dims.dimension(0)); EXPECT_EQ(20, flat_inner_outer_dims.dimension(1)); EXPECT_EQ(flat_inner_outer_dims(0, 0), 0.01f); @@ -460,20 +460,20 @@ TEST(Tensor_Float, Reshape) { } { LOG(INFO) << "flat_inner_outer_dims"; - auto flat_inner_outer_dims = zero_t.flat_inner_outer_dims(); + auto flat_inner_outer_dims = zero_t.flat_inner_outer_dims(0); EXPECT_EQ(3, flat_inner_outer_dims.dimension(0)); EXPECT_EQ(0, flat_inner_outer_dims.dimension(1)); } { LOG(INFO) << "flat_inner_outer_dims"; - auto flat_inner_outer_dims = zero_t.flat_inner_outer_dims(); + auto flat_inner_outer_dims = zero_t.flat_inner_outer_dims(0); EXPECT_EQ(3, flat_inner_outer_dims.dimension(0)); EXPECT_EQ(0, flat_inner_outer_dims.dimension(1)); EXPECT_EQ(0, flat_inner_outer_dims.dimension(2)); } { LOG(INFO) << "flat_inner_outer_dims"; - auto flat_inner_outer_dims = zero_t.flat_inner_outer_dims(); + auto flat_inner_outer_dims = zero_t.flat_inner_outer_dims(0); EXPECT_EQ(3, flat_inner_outer_dims.dimension(0)); EXPECT_EQ(0, flat_inner_outer_dims.dimension(1)); EXPECT_EQ(2, flat_inner_outer_dims.dimension(2)); @@ -482,20 +482,20 @@ TEST(Tensor_Float, Reshape) { } { LOG(INFO) << "flat_inner_outer_dims"; - auto flat_inner_outer_dims = zero_t.flat_inner_outer_dims(); + auto flat_inner_outer_dims = zero_t.flat_inner_outer_dims(3); EXPECT_EQ(0, flat_inner_outer_dims.dimension(0)); EXPECT_EQ(5, flat_inner_outer_dims.dimension(1)); } { LOG(INFO) << "flat_inner_outer_dims"; - auto flat_inner_outer_dims = zero_t.flat_inner_outer_dims(); + auto flat_inner_outer_dims = zero_t.flat_inner_outer_dims(2); EXPECT_EQ(0, flat_inner_outer_dims.dimension(0)); EXPECT_EQ(0, flat_inner_outer_dims.dimension(1)); EXPECT_EQ(5, flat_inner_outer_dims.dimension(2)); } { LOG(INFO) << "flat_inner_outer_dims"; - auto flat_inner_outer_dims = zero_t.flat_inner_outer_dims(); + auto flat_inner_outer_dims = zero_t.flat_inner_outer_dims(1); EXPECT_EQ(0, flat_inner_outer_dims.dimension(0)); EXPECT_EQ(2, flat_inner_outer_dims.dimension(1)); EXPECT_EQ(0, flat_inner_outer_dims.dimension(2)); -- GitLab From 377eca1517f783ca171698aa056a7329d48e88b8 Mon Sep 17 00:00:00 2001 From: Benoit Steiner Date: Wed, 26 Apr 2017 08:05:47 -0800 Subject: [PATCH 004/697] Improved support for lookup tables. These tables are often initialized from large files stored on disk, and can take tens of minutes to load: increased the initialization timeout to give grappler enough time to load them. Change: 154303608 --- .../core/grappler/clusters/single_machine.cc | 19 +++++++++++++++---- .../core/grappler/clusters/single_machine.h | 4 ++++ tensorflow/core/grappler/grappler_item.h | 2 ++ .../core/grappler/grappler_item_builder.cc | 15 ++++++++++----- 4 files changed, 31 insertions(+), 9 deletions(-) diff --git a/tensorflow/core/grappler/clusters/single_machine.cc b/tensorflow/core/grappler/clusters/single_machine.cc index 219d95acdf..09c8d55efd 100644 --- a/tensorflow/core/grappler/clusters/single_machine.cc +++ b/tensorflow/core/grappler/clusters/single_machine.cc @@ -31,6 +31,7 @@ namespace grappler { SingleMachine::SingleMachine(int timeout_s, int num_cpu_cores, int num_gpus) : Cluster(timeout_s), num_gpus_(num_gpus), + expected_init_time_s_(0), closing_(false) { thread_pool_.reset(new thread::ThreadPool( Env::Default(), SanitizeThreadSuffix("single_machine"), 2)); @@ -82,6 +83,7 @@ Status SingleMachine::Initialize(const GrapplerItem& item) { mutex_lock l(this->last_graph_mu_); if (last_graph_ != &item.graph || last_graph_id_ != item.id) { init_ops_ = item.init_ops; + expected_init_time_s_ = item.expected_init_time; last_graph_ = nullptr; queue_runner_defs_ = item.queue_runners; last_graph_id_ = item.id; @@ -100,7 +102,9 @@ Status SingleMachine::Run(const GraphDef& graph_def, TF_RETURN_IF_ERROR(session_->Create(graph_def)); if (!init_ops_.empty()) { init_metadata_ = RunMetadata(); - TF_RETURN_IF_ERROR(RunWithTimeout({}, init_ops_, &init_metadata_)); + int64 timeout_s = timeout_s_ + expected_init_time_s_; + TF_RETURN_IF_ERROR( + RunWithTimeout({}, init_ops_, &init_metadata_, timeout_s)); // The compute cost for init ops is likely to be pessimistic since init // ops are run only once before warmup. Therefore we only keep their // memory costs. @@ -143,6 +147,13 @@ Status SingleMachine::Run(const GraphDef& graph_def, Status SingleMachine::RunWithTimeout( const std::vector>& feed, const std::vector& fetch, RunMetadata* run_metadata) { + return RunWithTimeout(feed, fetch, run_metadata, timeout_s_); +} + +Status SingleMachine::RunWithTimeout( + const std::vector>& feed, + const std::vector& fetch, RunMetadata* run_metadata, + int64 timeout_s) { // We shouldn't be running or closing the session at this point. { mutex_lock l(close_mu_); @@ -155,10 +166,10 @@ Status SingleMachine::RunWithTimeout( *status = session_->Run(run_options_, feed, {}, fetch, nullptr, local_metadata.get()); }, - timeout_s_ * 1000, thread_pool_.get()); + timeout_s * 1000, thread_pool_.get()); if (!executed_in_time) { - return errors::DeadlineExceeded("Failed to run the graph after ", - timeout_s_, " seconds, aborting"); + return errors::DeadlineExceeded("Failed to run the graph after ", timeout_s, + " seconds, aborting"); } else if (run_metadata && status->ok()) { *run_metadata = *local_metadata; } diff --git a/tensorflow/core/grappler/clusters/single_machine.h b/tensorflow/core/grappler/clusters/single_machine.h index b739b39f2c..f69b11df5d 100644 --- a/tensorflow/core/grappler/clusters/single_machine.h +++ b/tensorflow/core/grappler/clusters/single_machine.h @@ -42,6 +42,9 @@ class SingleMachine : public Cluster { Status RunWithTimeout(const std::vector>& feed, const std::vector& fetch, RunMetadata* run_metadata); + Status RunWithTimeout(const std::vector>& feed, + const std::vector& fetch, + RunMetadata* run_metadata, int64 timeout_s); Status ResetSession(); Status CloseSession(bool use_timeout); @@ -52,6 +55,7 @@ class SingleMachine : public Cluster { mutex last_graph_mu_; const GraphDef* last_graph_ GUARDED_BY(last_graph_mu_) = nullptr; std::vector init_ops_; + int64 expected_init_time_s_; std::unique_ptr coordinator_; std::unique_ptr thread_pool_; diff --git a/tensorflow/core/grappler/grappler_item.h b/tensorflow/core/grappler/grappler_item.h index cb21ae54f0..e0709c682b 100644 --- a/tensorflow/core/grappler/grappler_item.h +++ b/tensorflow/core/grappler/grappler_item.h @@ -42,6 +42,8 @@ struct GrapplerItem { // Initialization op(s). std::vector init_ops; + // Expected initialization time in seconds, or 0 if unknown + int64 expected_init_time = 0; // Queue runner(s) required to run the queue(s) of this model. std::vector queue_runners; diff --git a/tensorflow/core/grappler/grappler_item_builder.cc b/tensorflow/core/grappler/grappler_item_builder.cc index e37b908fc6..88799ba881 100644 --- a/tensorflow/core/grappler/grappler_item_builder.cc +++ b/tensorflow/core/grappler/grappler_item_builder.cc @@ -86,11 +86,6 @@ std::unique_ptr GrapplerItemFromMetaGraphDef( } for (auto& node : *new_item->graph.mutable_node()) { - // Delete user specified placement if requested. - if (cfg.ignore_user_placement) { - node.clear_device(); - } - if (IsPlaceholder(node)) { if (node.attr().count("dtype") == 0) { LOG(ERROR) << "Unknown type for placeholder " << node.name() @@ -142,6 +137,11 @@ std::unique_ptr GrapplerItemFromMetaGraphDef( new_item->feed.emplace_back(node.name(), fake_input); } + // Delete user specified placement if requested. + if (cfg.ignore_user_placement) { + node.clear_device(); + } + // Delete colocation constraints if requested. if (cfg.ignore_colocation) { auto attr = node.mutable_attr(); auto it = attr->find("_class"); @@ -173,6 +173,11 @@ std::unique_ptr GrapplerItemFromMetaGraphDef( if (inits.has_node_list()) { for (const auto& node : inits.node_list().value()) { new_item->init_ops.push_back(node); + // Tables are initialized from files, which can take a long time. Add 30 + // minutes to the initialization time for each table to avoid timing + // out. + // TODO(bsteiner): adjust the timeout based on the file size. + new_item->expected_init_time += 30 * 60; } } } -- GitLab From cfbeafe11d9b86f8685c1c0f97d285885b5a5f1f Mon Sep 17 00:00:00 2001 From: Eugene Brevdo Date: Wed, 26 Apr 2017 08:13:38 -0800 Subject: [PATCH 005/697] Add eager (static) checking to assert_equal when constant values are available. This raises errors during graph creation, when possible, instead of at runtime. Change: 154304392 --- .../python/kernel_tests/binomial_test.py | 7 ++- .../dirichlet_multinomial_test.py | 4 +- .../kernel_tests/distribution_util_test.py | 44 +++++++------ .../python/kernel_tests/geometric_test.py | 7 ++- .../python/kernel_tests/multinomial_test.py | 4 +- .../kernel_tests/negative_binomial_test.py | 8 ++- .../python/kernel_tests/poisson_test.py | 9 ++- .../quantized_distribution_test.py | 5 +- .../python/kernel_tests/check_ops_test.py | 24 ++++++-- tensorflow/python/ops/check_ops.py | 61 +++++++++++++------ 10 files changed, 116 insertions(+), 57 deletions(-) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/binomial_test.py b/tensorflow/contrib/distributions/python/kernel_tests/binomial_test.py index c473d54f47..d30f6e418d 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/binomial_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/binomial_test.py @@ -19,7 +19,9 @@ from __future__ import print_function import numpy as np from scipy import stats from tensorflow.contrib.distributions.python.ops import binomial +from tensorflow.python.framework import dtypes from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import array_ops from tensorflow.python.platform import test @@ -97,13 +99,14 @@ class BinomialTest(test.TestCase): binom.prob([3., 1, 2]).eval() binom.cdf([2., 3, 2]).eval() binom.cdf([3., 1, 2]).eval() + placeholder = array_ops.placeholder(dtypes.float32) # Both equality and integer checking fail. with self.assertRaisesOpError( "cannot contain fractional components."): - binom.prob([1.0, 2.5, 1.5]).eval() + binom.prob(placeholder).eval(feed_dict={placeholder: [1.0, 2.5, 1.5]}) with self.assertRaisesOpError( "cannot contain fractional components."): - binom.cdf([1.0, 2.5, 1.5]).eval() + binom.cdf(placeholder).eval(feed_dict={placeholder: [1.0, 2.5, 1.5]}) binom = binomial.Binomial(total_count=n, probs=p, validate_args=False) binom.prob([1., 2., 3.]).eval() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/dirichlet_multinomial_test.py b/tensorflow/contrib/distributions/python/kernel_tests/dirichlet_multinomial_test.py index 54691d2095..bc25366cfa 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/dirichlet_multinomial_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/dirichlet_multinomial_test.py @@ -18,6 +18,7 @@ from __future__ import print_function import numpy as np from tensorflow.contrib import distributions +from tensorflow.python.framework import dtypes from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops @@ -87,9 +88,10 @@ class DirichletMultinomialTest(test.TestCase): dist.prob([3., 0, 2]).eval() dist.prob([3.0, 0, 2.0]).eval() # Both equality and integer checking fail. + placeholder = array_ops.placeholder(dtypes.float32) with self.assertRaisesOpError( "counts cannot contain fractional components"): - dist.prob([1.0, 2.5, 1.5]).eval() + dist.prob(placeholder).eval(feed_dict={placeholder: [1.0, 2.5, 1.5]}) dist = ds.DirichletMultinomial(n, alpha, validate_args=False) dist.prob([1., 2., 3.]).eval() # Non-integer arguments work. diff --git a/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py b/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py index 5a4a6720f7..2b28392d35 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py @@ -41,42 +41,44 @@ from tensorflow.python.platform import tf_logging as logging class AssertCloseTest(test.TestCase): def testAssertCloseIntegerDtype(self): - x = [1, 5, 10, 15, 20] + x = array_ops.placeholder(dtypes.int32) y = x - z = [2, 5, 10, 15, 20] + z = array_ops.placeholder(dtypes.int32) + feed_dict = {x: [1, 5, 10, 15, 20], z: [2, 5, 10, 15, 20]} with self.test_session(): with ops.control_dependencies([distribution_util.assert_close(x, y)]): - array_ops.identity(x).eval() + array_ops.identity(x).eval(feed_dict=feed_dict) with ops.control_dependencies([distribution_util.assert_close(y, x)]): - array_ops.identity(x).eval() + array_ops.identity(x).eval(feed_dict=feed_dict) with self.assertRaisesOpError("Condition x ~= y"): with ops.control_dependencies([distribution_util.assert_close(x, z)]): - array_ops.identity(x).eval() + array_ops.identity(x).eval(feed_dict=feed_dict) with self.assertRaisesOpError("Condition x ~= y"): with ops.control_dependencies([distribution_util.assert_close(y, z)]): - array_ops.identity(y).eval() + array_ops.identity(y).eval(feed_dict=feed_dict) def testAssertCloseNonIntegerDtype(self): - x = np.array([1., 5, 10, 15, 20], dtype=np.float32) + x = array_ops.placeholder(dtypes.float32) y = x + 1e-8 - z = [2., 5, 10, 15, 20] + z = array_ops.placeholder(dtypes.float32) + feed_dict = {x: [1., 5, 10, 15, 20], z: [2., 5, 10, 15, 20]} with self.test_session(): with ops.control_dependencies([distribution_util.assert_close(x, y)]): - array_ops.identity(x).eval() + array_ops.identity(x).eval(feed_dict=feed_dict) with ops.control_dependencies([distribution_util.assert_close(y, x)]): - array_ops.identity(x).eval() + array_ops.identity(x).eval(feed_dict=feed_dict) with self.assertRaisesOpError("Condition x ~= y"): with ops.control_dependencies([distribution_util.assert_close(x, z)]): - array_ops.identity(x).eval() + array_ops.identity(x).eval(feed_dict=feed_dict) with self.assertRaisesOpError("Condition x ~= y"): with ops.control_dependencies([distribution_util.assert_close(y, z)]): - array_ops.identity(y).eval() + array_ops.identity(y).eval(feed_dict=feed_dict) def testAssertCloseEpsilon(self): x = [0., 5, 10, 15, 20] @@ -98,30 +100,32 @@ class AssertCloseTest(test.TestCase): def testAssertIntegerForm(self): # This should only be detected as an integer. - x = [1., 5, 10, 15, 20] - y = [1.1, 5, 10, 15, 20] + x = array_ops.placeholder(dtypes.float32) + y = array_ops.placeholder(dtypes.float32) # First component isn't less than float32.eps = 1e-7 - z = [1.0001, 5, 10, 15, 20] + z = array_ops.placeholder(dtypes.float32) # This shouldn"t be detected as an integer. - w = [1e-8, 5, 10, 15, 20] + w = array_ops.placeholder(dtypes.float32) + feed_dict = {x: [1., 5, 10, 15, 20], y: [1.1, 5, 10, 15, 20], + z: [1.0001, 5, 10, 15, 20], w: [1e-8, 5, 10, 15, 20]} with self.test_session(): with ops.control_dependencies([distribution_util.assert_integer_form(x)]): - array_ops.identity(x).eval() + array_ops.identity(x).eval(feed_dict=feed_dict) with self.assertRaisesOpError("x has non-integer components"): with ops.control_dependencies( [distribution_util.assert_integer_form(y)]): - array_ops.identity(y).eval() + array_ops.identity(y).eval(feed_dict=feed_dict) with self.assertRaisesOpError("x has non-integer components"): with ops.control_dependencies( [distribution_util.assert_integer_form(z)]): - array_ops.identity(z).eval() + array_ops.identity(z).eval(feed_dict=feed_dict) with self.assertRaisesOpError("x has non-integer components"): with ops.control_dependencies( [distribution_util.assert_integer_form(w)]): - array_ops.identity(w).eval() + array_ops.identity(w).eval(feed_dict=feed_dict) class GetLogitsAndProbsTest(test.TestCase): diff --git a/tensorflow/contrib/distributions/python/kernel_tests/geometric_test.py b/tensorflow/contrib/distributions/python/kernel_tests/geometric_test.py index 3dbad7b607..9ef68c4c2c 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/geometric_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/geometric_test.py @@ -22,7 +22,9 @@ import numpy as np from scipy import stats from tensorflow.contrib.distributions.python.ops import geometric from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import array_ops from tensorflow.python.platform import test @@ -74,12 +76,13 @@ class GeometricTest(test.TestCase): with self.test_session(): batch_size = 6 probs = constant_op.constant([.9] * batch_size) - x = np.array([2.5, 3.2, 4.3, 5.1, 6., 7.], dtype=np.float32) + x = array_ops.placeholder(dtypes.float32, shape=[6]) + feed_dict = {x: [2.5, 3.2, 4.3, 5.1, 6., 7.]} geom = geometric.Geometric(probs=probs) with self.assertRaisesOpError("Condition x == y"): log_prob = geom.log_prob(x) - log_prob.eval() + log_prob.eval(feed_dict=feed_dict) with self.assertRaisesOpError("Condition x >= 0"): log_prob = geom.log_prob(np.array([-1.], dtype=np.float32)) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/multinomial_test.py b/tensorflow/contrib/distributions/python/kernel_tests/multinomial_test.py index 169498be24..b1c0c9f7a9 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/multinomial_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/multinomial_test.py @@ -18,6 +18,7 @@ from __future__ import print_function import numpy as np from tensorflow.contrib import distributions +from tensorflow.python.framework import dtypes from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops @@ -100,9 +101,10 @@ class MultinomialTest(test.TestCase): with self.assertRaisesOpError("counts must sum to `self.total_count`"): multinom.prob([2., 3, 2]).eval() # Counts are non-integers. + x = array_ops.placeholder(dtypes.float32) with self.assertRaisesOpError( "cannot contain fractional components."): - multinom.prob([1.0, 2.5, 1.5]).eval() + multinom.prob(x).eval(feed_dict={x: [1.0, 2.5, 1.5]}) multinom = ds.Multinomial(total_count=n, probs=p, validate_args=False) multinom.prob([1., 2., 2.]).eval() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/negative_binomial_test.py b/tensorflow/contrib/distributions/python/kernel_tests/negative_binomial_test.py index f55de99396..c1a74c6483 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/negative_binomial_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/negative_binomial_test.py @@ -21,6 +21,7 @@ import numpy as np from scipy import stats from tensorflow.contrib.distributions.python.ops import negative_binomial from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops @@ -140,17 +141,18 @@ class NegativeBinomialTest(test.TestCase): batch_size = 6 probs = [.9] * batch_size total_count = 5. - x = np.array([2.5, 3.2, 4.3, 5.1, 6., 7.], dtype=np.float32) + x = array_ops.placeholder(dtypes.float32, shape=[6]) + feed_dict = {x: [2.5, 3.2, 4.3, 5.1, 6., 7.]} negbinom = negative_binomial.NegativeBinomial( total_count=total_count, probs=probs, validate_args=True) with self.assertRaisesOpError("Condition x == y"): log_pmf = negbinom.log_prob(x) - log_pmf.eval() + log_pmf.eval(feed_dict=feed_dict) with self.assertRaisesOpError("Condition x >= 0"): log_pmf = negbinom.log_prob([-1.]) - log_pmf.eval() + log_pmf.eval(feed_dict=feed_dict) negbinom = negative_binomial.NegativeBinomial( total_count=total_count, probs=probs, validate_args=False) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/poisson_test.py b/tensorflow/contrib/distributions/python/kernel_tests/poisson_test.py index b1a9478b43..f157c0d3ed 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/poisson_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/poisson_test.py @@ -21,7 +21,9 @@ import numpy as np from scipy import stats from tensorflow.contrib.distributions.python.ops import poisson as poisson_lib from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import array_ops from tensorflow.python.platform import test @@ -64,17 +66,18 @@ class PoissonTest(test.TestCase): with self.test_session(): batch_size = 6 lam = constant_op.constant([3.0] * batch_size) - x = [2.5, 3.2, 4.3, 5.1, 6., 7.] + x = array_ops.placeholder(dtypes.float32, shape=[6]) + feed_dict = {x: [2.5, 3.2, 4.3, 5.1, 6., 7.]} poisson = poisson_lib.Poisson(rate=lam, validate_args=True) # Non-integer with self.assertRaisesOpError("cannot contain fractional components"): log_pmf = poisson.log_prob(x) - log_pmf.eval() + log_pmf.eval(feed_dict=feed_dict) with self.assertRaisesOpError("Condition x >= 0"): log_pmf = poisson.log_prob([-1.]) - log_pmf.eval() + log_pmf.eval(feed_dict=feed_dict) poisson = poisson_lib.Poisson(rate=lam, validate_args=False) log_pmf = poisson.log_prob(x) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/quantized_distribution_test.py b/tensorflow/contrib/distributions/python/kernel_tests/quantized_distribution_test.py index 0e2d143732..6a7ee3a8bf 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/quantized_distribution_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/quantized_distribution_test.py @@ -373,15 +373,16 @@ class QuantizedDistributionTest(test.TestCase): def testCutoffsMustBeIntegerValuedIfValidateArgsTrue(self): with self.test_session(): + low = array_ops.placeholder(dtypes.float32) qdist = distributions.QuantizedDistribution( distribution=distributions.Normal(loc=0., scale=1.), - low=1.5, + low=low, high=10., validate_args=True) self.assertTrue(qdist.validate_args) # Default is True. with self.assertRaisesOpError("has non-integer components"): - qdist.sample().eval() + qdist.sample().eval(feed_dict={low: 1.5}) def testCutoffsCanBeFloatValuedIfValidateArgsFalse(self): with self.test_session(): diff --git a/tensorflow/python/kernel_tests/check_ops_test.py b/tensorflow/python/kernel_tests/check_ops_test.py index d688b16478..ed859e3774 100644 --- a/tensorflow/python/kernel_tests/check_ops_test.py +++ b/tensorflow/python/kernel_tests/check_ops_test.py @@ -80,23 +80,35 @@ class AssertEqualTest(test.TestCase): def test_raises_when_greater(self): with self.test_session(): - small = constant_op.constant([1, 2], name="small") - big = constant_op.constant([3, 4], name="big") + # Static check + static_small = constant_op.constant([1, 2], name="small") + static_big = constant_op.constant([3, 4], name="big") + with self.assertRaisesRegexp(ValueError, "fail"): + check_ops.assert_equal(static_big, static_small, message="fail") + # Dynamic check + small = array_ops.placeholder(dtypes.int32, name="small") + big = array_ops.placeholder(dtypes.int32, name="big") with ops.control_dependencies( [check_ops.assert_equal( big, small, message="fail")]): out = array_ops.identity(small) with self.assertRaisesOpError("fail.*big.*small"): - out.eval() + out.eval(feed_dict={small: [1, 2], big: [3, 4]}) def test_raises_when_less(self): with self.test_session(): - small = constant_op.constant([3, 1], name="small") - big = constant_op.constant([4, 2], name="big") + # Static check + static_small = constant_op.constant([3, 1], name="small") + static_big = constant_op.constant([4, 2], name="big") + with self.assertRaisesRegexp(ValueError, "fail"): + check_ops.assert_equal(static_big, static_small, message="fail") + # Dynamic check + small = array_ops.placeholder(dtypes.int32, name="small") + big = array_ops.placeholder(dtypes.int32, name="big") with ops.control_dependencies([check_ops.assert_equal(small, big)]): out = array_ops.identity(small) with self.assertRaisesOpError("small.*big"): - out.eval() + out.eval(feed_dict={small: [3, 1], big: [4, 2]}) def test_doesnt_raise_when_equal_and_broadcastable_shapes(self): with self.test_session(): diff --git a/tensorflow/python/ops/check_ops.py b/tensorflow/python/ops/check_ops.py index 8401f5493b..753999a672 100644 --- a/tensorflow/python/ops/check_ops.py +++ b/tensorflow/python/ops/check_ops.py @@ -84,6 +84,22 @@ __all__ = [ ] +def _maybe_constant_value_string(t): + if not isinstance(t, ops.Tensor): + return str(t) + const_t = tensor_util.constant_value(t) + if const_t is not None: + return str(const_t) + return t + + +def _assert_static(condition, data): + """Raises a static ValueError with as much information as possible.""" + if not condition: + data_static = [_maybe_constant_value_string(x) for x in data] + raise ValueError('\n'.join(data_static)) + + def assert_proper_iterable(values): """Static assert that values is a "proper" iterable. @@ -140,7 +156,9 @@ def assert_negative(x, data=None, summarize=None, message=None, name=None): x = ops.convert_to_tensor(x, name='x') if data is None: data = [ - message, 'Condition x < 0 did not hold element-wise: x = ', x.name, x] + message, + 'Condition x < 0 did not hold element-wise:', + 'x (%s) = ' % x.name, x] zero = ops.convert_to_tensor(0, dtype=x.dtype) return assert_less(x, zero, data=data, summarize=summarize) @@ -174,7 +192,8 @@ def assert_positive(x, data=None, summarize=None, message=None, name=None): x = ops.convert_to_tensor(x, name='x') if data is None: data = [ - message, 'Condition x > 0 did not hold element-wise: x = ', x.name, x] + message, 'Condition x > 0 did not hold element-wise:', + 'x (%s) = ' % x.name, x] zero = ops.convert_to_tensor(0, dtype=x.dtype) return assert_less(zero, x, data=data, summarize=summarize) @@ -210,7 +229,8 @@ def assert_non_negative(x, data=None, summarize=None, message=None, name=None): if data is None: data = [ message, - 'Condition x >= 0 did not hold element-wise: x = ', x.name, x] + 'Condition x >= 0 did not hold element-wise:', + 'x (%s) = ' % x.name, x] zero = ops.convert_to_tensor(0, dtype=x.dtype) return assert_less_equal(zero, x, data=data, summarize=summarize) @@ -246,7 +266,8 @@ def assert_non_positive(x, data=None, summarize=None, message=None, name=None): if data is None: data = [ message, - 'Condition x <= 0 did not hold element-wise: x = ', x.name, x] + 'Condition x <= 0 did not hold element-wise:' + 'x (%s) = ' % x.name, x] zero = ops.convert_to_tensor(0, dtype=x.dtype) return assert_less_equal(x, zero, data=data, summarize=summarize) @@ -284,10 +305,16 @@ def assert_equal(x, y, data=None, summarize=None, message=None, name=None): if data is None: data = [ message, - 'Condition x == y did not hold element-wise: x = ', x.name, x, 'y = ', - y.name, y + 'Condition x == y did not hold element-wise:', + 'x (%s) = ' % x.name, x, + 'y (%s) = ' % y.name, y ] condition = math_ops.reduce_all(math_ops.equal(x, y)) + x_static = tensor_util.constant_value(x) + y_static = tensor_util.constant_value(y) + if x_static is not None and y_static is not None: + condition_static = (x_static == y_static).all() + _assert_static(condition_static, data) return control_flow_ops.Assert(condition, data, summarize=summarize) @@ -326,9 +353,9 @@ def assert_none_equal( if data is None: data = [ message, - 'Condition x != y did not hold for every single element: x = ', - x.name, x, - 'y = ', y.name, y + 'Condition x != y did not hold for every single element:' + 'x (%s) = ' % x.name, x, + 'y (%s) = ' % y.name, y ] condition = math_ops.reduce_all(math_ops.not_equal(x, y)) return control_flow_ops.Assert(condition, data, summarize=summarize) @@ -367,8 +394,8 @@ def assert_less(x, y, data=None, summarize=None, message=None, name=None): if data is None: data = [ message, - 'Condition x < y did not hold element-wise: x = ', x.name, x, 'y = ', - y.name, y + 'Condition x < y did not hold element-wise:' + 'x (%s) = ' % x.name, x, 'y (%s) = ' % y.name, y ] condition = math_ops.reduce_all(math_ops.less(x, y)) return control_flow_ops.Assert(condition, data, summarize=summarize) @@ -407,8 +434,8 @@ def assert_less_equal(x, y, data=None, summarize=None, message=None, name=None): if data is None: data = [ message, - 'Condition x <= y did not hold element-wise: x = ', x.name, x, 'y = ', - y.name, y + 'Condition x <= y did not hold element-wise:' + 'x (%s) = ' % x.name, x, 'y (%s) = ' % y.name, y ] condition = math_ops.reduce_all(math_ops.less_equal(x, y)) return control_flow_ops.Assert(condition, data, summarize=summarize) @@ -447,8 +474,8 @@ def assert_greater(x, y, data=None, summarize=None, message=None, name=None): if data is None: data = [ message, - 'Condition x > y did not hold element-wise: x = ', x.name, x, 'y = ', - y.name, y + 'Condition x > y did not hold element-wise:' + 'x (%s) = ' % x.name, x, 'y (%s) = ' % y.name, y ] condition = math_ops.reduce_all(math_ops.greater(x, y)) return control_flow_ops.Assert(condition, data, summarize=summarize) @@ -489,8 +516,8 @@ def assert_greater_equal(x, y, data=None, summarize=None, message=None, if data is None: data = [ message, - 'Condition x >= y did not hold element-wise: x = ', x.name, x, 'y = ', - y.name, y + 'Condition x >= y did not hold element-wise:' + 'x (%s) = ' % x.name, x, 'y (%s) = ' % y.name, y ] condition = math_ops.reduce_all(math_ops.greater_equal(x, y)) return control_flow_ops.Assert(condition, data, summarize=summarize) -- GitLab From 9845d0e822f1a8ec455507f48541b379c9778829 Mon Sep 17 00:00:00 2001 From: Geoffrey Irving Date: Wed, 26 Apr 2017 08:53:36 -0800 Subject: [PATCH 006/697] Add tf.log_sigmoid This is a numerically stable version of tf.log(tf.sigmoid(x)). It's just -tf.nn.softplus(-x), but it's easy to add and the identity is easy to mistype. RELNOTES: Add tf.log_sigmoid(x) = tf.log(tf.sigmoid(x)) = -tf.nn.softplus(-x). Fixes #3719. Change: 154308666 --- tensorflow/python/BUILD | 1 + tensorflow/python/kernel_tests/BUILD | 1 + .../python/kernel_tests/cwise_ops_test.py | 5 +++++ tensorflow/python/ops/math_ops.py | 19 +++++++++++++++++++ tensorflow/python/ops/nn.py | 1 + tensorflow/python/ops/standard_ops.py | 1 + tensorflow/tools/api/golden/tensorflow.pbtxt | 4 ++++ 7 files changed, 32 insertions(+) diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 15845506d9..2a61735bb5 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -1499,6 +1499,7 @@ py_library( ":framework_ops", ":graph_util", ":math_ops_gen", + ":nn_ops_gen", ":sparse_ops_gen", ":sparse_tensor", ":spectral_ops_gen", diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD index cdb15e0a91..06a0aa468a 100644 --- a/tensorflow/python/kernel_tests/BUILD +++ b/tensorflow/python/kernel_tests/BUILD @@ -2347,6 +2347,7 @@ cuda_py_test( "//tensorflow/python:gradients", "//tensorflow/python:math_ops", "//tensorflow/python:math_ops_gen", + "//tensorflow/python:nn_grad", "//tensorflow/python:platform", "//tensorflow/python:variables", ], diff --git a/tensorflow/python/kernel_tests/cwise_ops_test.py b/tensorflow/python/kernel_tests/cwise_ops_test.py index 9cd42d8851..cd0d33ecf3 100644 --- a/tensorflow/python/kernel_tests/cwise_ops_test.py +++ b/tensorflow/python/kernel_tests/cwise_ops_test.py @@ -31,6 +31,7 @@ from tensorflow.python.ops import gen_math_ops from tensorflow.python.ops import gradient_checker from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_grad # pylint: disable=unused-import from tensorflow.python.ops import variables from tensorflow.python.platform import test from tensorflow.python.platform import tf_logging @@ -165,6 +166,9 @@ class UnaryOpTest(test.TestCase): def _sigmoid(self, x): return 1.0 / (1.0 + np.exp(-x)) + def _log_sigmoid(self, x): + return np.log(self._sigmoid(x)) + def _replace_domain_error_with_inf(self, fn): def func(x): @@ -198,6 +202,7 @@ class UnaryOpTest(test.TestCase): self._compareBoth(z, np.log1p, math_ops.log1p) self._compareBoth(x, np.tanh, math_ops.tanh) self._compareBoth(x, self._sigmoid, math_ops.sigmoid) + self._compareBoth(x, self._log_sigmoid, math_ops.log_sigmoid) self._compareBoth(y, np.sign, math_ops.sign) self._compareBoth(x, np.sin, math_ops.sin) self._compareBoth(x, np.cos, math_ops.cos) diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py index 36a4184950..5f4ce63f31 100644 --- a/tensorflow/python/ops/math_ops.py +++ b/tensorflow/python/ops/math_ops.py @@ -151,6 +151,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_control_flow_ops from tensorflow.python.ops import gen_data_flow_ops from tensorflow.python.ops import gen_math_ops +from tensorflow.python.ops import gen_nn_ops from tensorflow.python.ops import gen_sparse_ops from tensorflow.python.ops import gen_spectral_ops from tensorflow.python.ops import gen_state_ops @@ -2004,6 +2005,24 @@ def sigmoid(x, name=None): return gen_math_ops._sigmoid(x, name=name) +def log_sigmoid(x, name=None): + """Computes log sigmoid of `x` element-wise. + + Specifically, `y = log(1 / (1 + exp(-x)))`. For numerical stability, + we use `y = -tf.nn.softplus(-x)`. + + Args: + x: A Tensor with type `float32` or `float64`. + name: A name for the operation (optional). + + Returns: + A Tensor with the same type as `x`. + """ + with ops.name_scope(name, "LogSigmoid", [x]) as name: + x = ops.convert_to_tensor(x, name="x") + return gen_math_ops._neg(gen_nn_ops.softplus(-x), name=name) + + def tanh(x, name=None): """Computes hyperbolic tangent of `x` element-wise. diff --git a/tensorflow/python/ops/nn.py b/tensorflow/python/ops/nn.py index c5c5169231..7b6494e0c9 100644 --- a/tensorflow/python/ops/nn.py +++ b/tensorflow/python/ops/nn.py @@ -27,6 +27,7 @@ See the @{$python/nn} guide. @@dropout @@bias_add @@sigmoid +@@log_sigmoid @@tanh @@convolution @@conv2d diff --git a/tensorflow/python/ops/standard_ops.py b/tensorflow/python/ops/standard_ops.py index 9d24eb242d..09e04d4247 100644 --- a/tensorflow/python/ops/standard_ops.py +++ b/tensorflow/python/ops/standard_ops.py @@ -145,6 +145,7 @@ _allowed_symbols_math_ops = [ # These are documented in nn. # We are are not importing nn because it would create a circular dependency. "sigmoid", + "log_sigmoid", "tanh", ] diff --git a/tensorflow/tools/api/golden/tensorflow.pbtxt b/tensorflow/tools/api/golden/tensorflow.pbtxt index b0e0924093..a8d3dd65c7 100644 --- a/tensorflow/tools/api/golden/tensorflow.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.pbtxt @@ -1168,6 +1168,10 @@ tf_module { name: "log1p" argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "log_sigmoid" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "logical_and" argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " -- GitLab From 59258116924451bfc8ed5eba9d3efa4d12df47f1 Mon Sep 17 00:00:00 2001 From: Benoit Steiner Date: Wed, 26 Apr 2017 10:12:14 -0800 Subject: [PATCH 007/697] GrapplerItem is a struct and not a class Change: 154319534 --- .../core/grappler/inputs/trivial_test_graph_input_yielder.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h b/tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h index 4c5600c816..434b660614 100644 --- a/tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h +++ b/tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h @@ -24,7 +24,7 @@ namespace tensorflow { namespace grappler { class Cluster; -class GrapplerItem; +struct GrapplerItem; class TrivialTestGraphInputYielder : public InputYielder { public: -- GitLab From 320719d76483fba26fc92dfc3cda1736201efd65 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dandelion=20Man=C3=A9?= Date: Wed, 26 Apr 2017 10:13:05 -0800 Subject: [PATCH 008/697] Migrate tf_storage and tf_globals to d3v4. Change: 154319656 --- .../components/tf_globals_d3v4/globals.ts | 38 ++ .../components/tf_storage_d3v4/storage.ts | 400 ++++++++++++++++++ .../tf_storage_d3v4/storageTests.ts | 64 +++ .../components/tf_storage_d3v4/tests.html | 30 ++ 4 files changed, 532 insertions(+) create mode 100644 tensorflow/tensorboard/components/tf_globals_d3v4/globals.ts create mode 100644 tensorflow/tensorboard/components/tf_storage_d3v4/storage.ts create mode 100644 tensorflow/tensorboard/components/tf_storage_d3v4/storageTests.ts create mode 100644 tensorflow/tensorboard/components/tf_storage_d3v4/tests.html diff --git a/tensorflow/tensorboard/components/tf_globals_d3v4/globals.ts b/tensorflow/tensorboard/components/tf_globals_d3v4/globals.ts new file mode 100644 index 0000000000..42c73708cd --- /dev/null +++ b/tensorflow/tensorboard/components/tf_globals_d3v4/globals.ts @@ -0,0 +1,38 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + + + +// The names of TensorBoard tabs. +export const TABS = [ + 'scalars', 'images', 'audio', 'graphs', 'distributions', 'histograms', + 'embeddings', 'text' +]; + +// If true, TensorBoard stores its hash in the URI state. +// If false, tab switching in TensorBoard will not update location hash, +// because hash updates interfere with wct_tests. +export let USE_HASH = false; + +let _fakeHash = ''; + +export function setFakeHash(h: string) { + _fakeHash = h; +} + +export function getFakeHash() { + return _fakeHash; +} + diff --git a/tensorflow/tensorboard/components/tf_storage_d3v4/storage.ts b/tensorflow/tensorboard/components/tf_storage_d3v4/storage.ts new file mode 100644 index 0000000000..19a27b9cbd --- /dev/null +++ b/tensorflow/tensorboard/components/tf_storage_d3v4/storage.ts @@ -0,0 +1,400 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +import {USE_HASH, setFakeHash, getFakeHash, TABS} from '../tf_globals_d3v4/globals'; +import * as _ from 'lodash'; + + +/* tslint:disable:no-namespace variable-name */ +/** + * The Storage Module provides storage for URL parameters, and an API for + * getting and setting TensorBoard's stateful URI. + * + * It generates URI components like: events&runPrefix=train* + * which TensorBoard uses after like localhost:8000/#events&runPrefix=train* + * to store state in the URI. + * + * It also allows saving the values to localStorage for long-term persistance. + */ +type StringDict = {[key: string]: string}; + +/** + * A key that users cannot use, since TensorBoard uses this to store info + * about the active tab. + */ +export let TAB = '__tab__'; + +/** + * The name of the property for users to set on a Polymer component + * in order for its stored properties to be stored in the URI unambiguously. + * (No need to set this if you want mutliple instances of the component to + * share URI state) + * + * Example: + * + * + * The disambiguator should be set to any unique value so that multiple + * instances of the component can store properties in URI storage. + * + * Because it's hard to dereference this variable in HTML property bindings, + * it is NOT safe to change the disambiguator string without find+replace + * across the codebase. + */ +export let DISAMBIGUATOR = 'disambiguator'; + +/** + * Return a string stored in URI or localStorage. + * Undefined if not found. + */ +export function getString(key: string, useLocalStorage: boolean): string { + if (useLocalStorage) { + return window.localStorage.getItem(key); + } else { + return _componentToDict(_readComponent())[key]; + } +} + +/** + * Set a string in URI or localStorage. + */ +export function setString( + key: string, value: string, useLocalStorage: boolean) { + if (useLocalStorage) { + window.localStorage.setItem(key, value); + } else { + const items = _componentToDict(_readComponent()); + items[key] = value; + _writeComponent(_dictToComponent(items)); + } +} + +/** + * Return a boolean stored in stored in URI or localStorage. + * Undefined if not found. + */ +export function getBoolean(key: string, useLocalStorage: boolean): boolean { + const item = getString(key, useLocalStorage); + return item === 'true' ? true : item === 'false' ? false : undefined; +} + +/** + * Store a boolean in URI or localStorage. + */ +export function setBoolean( + key: string, value: boolean, useLocalStorage = false) { + setString(key, value.toString(), useLocalStorage); +} + +/** + * Return a number stored in stored in URI or localStorage. + * Undefined if not found. + */ +export function getNumber(key: string, useLocalStorage: boolean): number { + const item = getString(key, useLocalStorage); + return item === undefined ? undefined : +item; +} + +/** + * Store a number in URI or localStorage. + */ +export function setNumber( + key: string, value: number, useLocalStorage: boolean) { + setString(key, '' + value, useLocalStorage); +} + +/** + * Return an object stored in stored in URI or localStorage. + * Undefined if not found. + */ +export function getObject(key: string, useLocalStorage: boolean): {} { + const item = getString(key, useLocalStorage); + return item === undefined ? undefined : JSON.parse(atob(item)); +} + +/** + * Store an object in URI or localStorage. + */ +export function setObject(key: string, value: {}, useLocalStorage: boolean) { + setString(key, btoa(JSON.stringify(value)), useLocalStorage); +} + +/** + * Get a unique storage name for a (Polymer component, propertyName) tuple. + * + * DISAMBIGUATOR must be set on the component, if other components use the + * same propertyName. + */ +export function getURIStorageName( + component: {}, propertyName: string): string { + const d = component[DISAMBIGUATOR]; + const components = d == null ? [propertyName] : [d, propertyName]; + return components.join('.'); +} + +/** + * Return a function that: + * (1) Initializes a Polymer boolean property with a default value, if its + * value is not already set + * (2) Sets up listener that updates Polymer property on hash change. + */ +export function getBooleanInitializer( + propertyName: string, defaultVal: boolean, + useLocalStorage = false): Function { + return _getInitializer( + getBoolean, propertyName, defaultVal, useLocalStorage); +} + +/** + * Return a function that: + * (1) Initializes a Polymer string property with a default value, if its + * value is not already set + * (2) Sets up listener that updates Polymer property on hash change. + */ +export function getStringInitializer( + propertyName: string, defaultVal: string, + useLocalStorage = false): Function { + return _getInitializer( + getString, propertyName, defaultVal, useLocalStorage); +} + +/** + * Return a function that: + * (1) Initializes a Polymer number property with a default value, if its + * value is not already set + * (2) Sets up listener that updates Polymer property on hash change. + */ +export function getNumberInitializer( + propertyName: string, defaultVal: number, + useLocalStorage = false): Function { + return _getInitializer( + getNumber, propertyName, defaultVal, useLocalStorage); +} + +/** + * Return a function that: + * (1) Initializes a Polymer Object property with a default value, if its + * value is not already set + * (2) Sets up listener that updates Polymer property on hash change. + * + * Generates a deep clone of the defaultVal to avoid mutation issues. + */ +export function getObjectInitializer( + propertyName: string, defaultVal: {}, useLocalStorage = false): Function { + return _getInitializer( + getObject, propertyName, defaultVal, useLocalStorage); +} + +/** + * Return a function that updates URIStorage when a string property changes. + */ +export function getBooleanObserver( + propertyName: string, defaultVal: boolean, + useLocalStorage = false): Function { + return _getObserver( + getBoolean, setBoolean, propertyName, defaultVal, useLocalStorage); +} + +/** + * Return a function that updates URIStorage when a string property changes. + */ +export function getStringObserver( + propertyName: string, defaultVal: string, + useLocalStorage = false): Function { + return _getObserver( + getString, setString, propertyName, defaultVal, useLocalStorage); +} + +/** + * Return a function that updates URIStorage when a number property changes. + */ +export function getNumberObserver( + propertyName: string, defaultVal: number, + useLocalStorage = false): Function { + return _getObserver( + getNumber, setNumber, propertyName, defaultVal, useLocalStorage); +} + +/** + * Return a function that updates URIStorage when an object property changes. + * Generates a deep clone of the defaultVal to avoid mutation issues. + */ +export function getObjectObserver( + propertyName: string, defaultVal: {}, useLocalStorage = false): Function { + const clone = _.cloneDeep(defaultVal); + return _getObserver( + getObject, setObject, propertyName, clone, useLocalStorage); +} + +/** + * Read component from URI (e.g. returns "events&runPrefix=train*"). + */ +function _readComponent(): string { + return USE_HASH ? window.location.hash.slice(1) : getFakeHash(); +} + +/** + * Write component to URI. + */ +function _writeComponent(component: string) { + if (USE_HASH) { + window.location.hash = component; + } else { + setFakeHash(component); + } +} + +/** + * Convert dictionary of strings into a URI Component. + * All key value entries get added as key value pairs in the component, + * with the exception of a key with the TAB value, which if present + * gets prepended to the URI Component string for backwards comptability + * reasons. + */ +function _dictToComponent(items: StringDict): string { + let component = ''; + + // Add the tab name e.g. 'events', 'images', 'histograms' as a prefix + // for backwards compatbility. + if (items[TAB] !== undefined) { + component += items[TAB]; + } + + // Join other strings with &key=value notation + const nonTab = _.pairs(items) + .filter((pair) => pair[0] !== TAB) + .map((pair) => { + return encodeURIComponent(pair[0]) + '=' + + encodeURIComponent(pair[1]); + }) + .join('&'); + + return nonTab.length > 0 ? (component + '&' + nonTab) : component; +} + +/** + * Convert a URI Component into a dictionary of strings. + * Component should consist of key-value pairs joined by a delimiter + * with the exception of the tabName. + * Returns dict consisting of all key-value pairs and + * dict[TAB] = tabName + */ +function _componentToDict(component: string): StringDict { + const items = {} as StringDict; + + const tokens = component.split('&'); + tokens.forEach((token) => { + const kv = token.split('='); + // Special backwards compatibility for URI components like #events + if (kv.length === 1 && _.contains(TABS, kv[0])) { + items[TAB] = kv[0]; + } else if (kv.length === 2) { + items[decodeURIComponent(kv[0])] = decodeURIComponent(kv[1]); + } + }); + return items; +} + +/** + * Return a function that: + * (1) Initializes a Polymer property with a default value, if its + * value is not already set + * (2) Sets up listener that updates Polymer property on hash change. + */ +function _getInitializer( + get: (name: string, useLocalStorage: boolean) => T, propertyName: string, + defaultVal: T, useLocalStorage): Function { + return function() { + const URIStorageName = getURIStorageName(this, propertyName); + // setComponentValue will be called every time the hash changes, and is + // responsible for ensuring that new state in the hash will be propagated + // to the component with that property. + // It is important that this function does not re-assign needlessly, + // to avoid Polymer observer churn. + const setComponentValue = () => { + const uriValue = get(URIStorageName, false); + const currentValue = this[propertyName]; + // if uriValue is undefined, we will ensure that the property has the + // default value + if (uriValue === undefined) { + let valueToSet: T; + // if we are using localStorage, we will set the value to the value + // from localStorage. Then, the corresponding observer will proxy + // the localStorage value into URI storage. + // in this way, localStorage takes precedence over the default val + // but not over the URI value. + if (useLocalStorage) { + const useLocalStorageValue = get(URIStorageName, true); + valueToSet = useLocalStorageValue === undefined ? + defaultVal : + useLocalStorageValue; + } else { + valueToSet = defaultVal; + } + if (!_.isEqual(currentValue, valueToSet)) { + // If we don't have an explicit URI value, then we need to ensure + // the property value is equal to the default value. + // We will assign a clone rather than the canonical default, because + // the component receiving this property may mutate it, and we need + // to keep a pristine copy of the default. + this[propertyName] = _.clone(valueToSet); + } + // In this case, we have an explicit URI value, so we will ensure that + // the component has an equivalent value. + } else { + if (!_.isEqual(uriValue, currentValue)) { + this[propertyName] = uriValue; + } + } + }; + // Set the value on the property. + setComponentValue(); + // Update it when the hashchanges. + window.addEventListener('hashchange', setComponentValue); + }; +} + +/** + * Return a function that updates URIStorage when a property changes. + */ +function _getObserver( + get: (name: string, useLocalStorage: boolean) => T, + set: (name: string, newVal: T, useLocalStorage: boolean) => void, + propertyName: string, defaultVal: T, useLocalStorage: boolean): Function { + return function() { + const URIStorageName = getURIStorageName(this, propertyName); + const newVal = this[propertyName]; + // if this is a localStorage property, we always synchronize the value + // in localStorage to match the one currently in the URI. + if (useLocalStorage) { + set(URIStorageName, newVal, true); + } + if (!_.isEqual(newVal, get(URIStorageName, false))) { + if (_.isEqual(newVal, defaultVal)) { + _unsetFromURI(URIStorageName); + } else { + set(URIStorageName, newVal, false); + } + } + }; +} + +/** + * Delete a key from the URI. + */ +function _unsetFromURI(key) { + const items = _componentToDict(_readComponent()); + delete items[key]; + _writeComponent(_dictToComponent(items)); +} + diff --git a/tensorflow/tensorboard/components/tf_storage_d3v4/storageTests.ts b/tensorflow/tensorboard/components/tf_storage_d3v4/storageTests.ts new file mode 100644 index 0000000000..adc4dde716 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_storage_d3v4/storageTests.ts @@ -0,0 +1,64 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the 'License'); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an 'AS IS' BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +import {TAB, getString, getNumber, getObject, setString, setNumber, setObject} from './storage'; +import {TABS} from '../tf_globals_d3v4/globals'; + +/* tslint:disable:no-namespace */ +describe('URIStorage', () => { + it('get/setString', () => { + setString('key_a', 'hello', false); + setString('key_b', 'there', false); + chai.assert.equal('hello', getString('key_a', false)); + chai.assert.equal('there', getString('key_b', false)); + chai.assert.equal(null, getString('key_c', false)); + }); + + it('get/setNumber', () => { + setNumber('key_a', 12, false); + setNumber('key_b', 3.4, false); + chai.assert.equal(12, getNumber('key_a', false)); + chai.assert.equal(3.4, getNumber('key_b', false)); + chai.assert.equal(null, getNumber('key_c', false)); + }); + + it('get/setObject', () => { + const obj = {'foo': 2.3, 'bar': 'barstr'}; + setObject('key_a', obj, false); + chai.assert.deepEqual(obj, getObject('key_a', false)); + }); + + it('get/setWeirdValues', () => { + setNumber('key_a', NaN, false); + chai.assert.deepEqual(NaN, getNumber('key_a', false)); + + setNumber('key_a', +Infinity, false); + chai.assert.equal(+Infinity, getNumber('key_a', false)); + + setNumber('key_a', -Infinity, false); + chai.assert.equal(-Infinity, getNumber('key_a', false)); + + setNumber('key_a', 1 / 3, false); + chai.assert.equal(1 / 3, getNumber('key_a', false)); + + setNumber('key_a', -0, false); + chai.assert.equal(-0, getNumber('key_a', false)); + }); + + it('set/getTab', () => { + setString(TAB, TABS[0], false); + chai.assert.equal(TABS[0], getString(TAB, false)); + }); +}); + diff --git a/tensorflow/tensorboard/components/tf_storage_d3v4/tests.html b/tensorflow/tensorboard/components/tf_storage_d3v4/tests.html new file mode 100644 index 0000000000..6d395b0702 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_storage_d3v4/tests.html @@ -0,0 +1,30 @@ + + + + + + + + + + + + + + + -- GitLab From ca5aa8a3c2ab96c7ebb050314c5a904f43a88046 Mon Sep 17 00:00:00 2001 From: Sergio Guadarrama Date: Wed, 26 Apr 2017 11:02:17 -0800 Subject: [PATCH 009/697] Allow SVStepCounter to measure progress for other counters beyond global_step. Change: 154326389 --- tensorflow/python/training/supervisor.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/tensorflow/python/training/supervisor.py b/tensorflow/python/training/supervisor.py index 9435bdfa1c..93e64b4ab0 100644 --- a/tensorflow/python/training/supervisor.py +++ b/tensorflow/python/training/supervisor.py @@ -1001,28 +1001,32 @@ class SVSummaryThread(coordinator.LooperThread): class SVStepCounterThread(coordinator.LooperThread): """Threads to count steps and measure their duration.""" - def __init__(self, sv, sess): + def __init__(self, sv, sess, step_counter=None): """Create a `SVStepCounterThread`. Args: sv: A `Supervisor`. sess: A `Session`. + step_counter: A `Tensor` holding the step counter. By defaults, it uses + sv.global_step. """ super(SVStepCounterThread, self).__init__(sv.coord, sv.save_summaries_secs) self._sv = sv self._sess = sess self._last_time = 0.0 self._last_step = 0 - self._summary_tag = "%s/sec" % self._sv.global_step.op.name + step_counter = sv.global_step if step_counter is None else step_counter + self._step_counter = step_counter + self._summary_tag = "%s/sec" % self._step_counter.op.name def start_loop(self): self._last_time = time.time() self._last_step = training_util.global_step( - self._sess, self._sv.global_step) + self._sess, self._step_counter) def run_loop(self): # Count the steps. - current_step = training_util.global_step(self._sess, self._sv.global_step) + current_step = training_util.global_step(self._sess, self._step_counter) added_steps = current_step - self._last_step self._last_step = current_step # Measure the elapsed time. @@ -1030,7 +1034,10 @@ class SVStepCounterThread(coordinator.LooperThread): elapsed_time = current_time - self._last_time self._last_time = current_time # Reports the number of steps done per second - steps_per_sec = added_steps / elapsed_time if elapsed_time != 0. else float("inf") + if elapsed_time > 0.: + steps_per_sec = added_steps / elapsed_time + else: + steps_per_sec = float("inf") summary = Summary(value=[Summary.Value(tag=self._summary_tag, simple_value=steps_per_sec)]) if self._sv.summary_writer: -- GitLab From deca449d2419464e30ab4c7baf5b727f497aea8e Mon Sep 17 00:00:00 2001 From: Shanqing Cai Date: Wed, 26 Apr 2017 11:44:11 -0800 Subject: [PATCH 010/697] Provide a user-defined default constructor to grappler::ItemConfig to avoid compiler error Change: 154331360 --- tensorflow/core/grappler/grappler_item_builder.h | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/tensorflow/core/grappler/grappler_item_builder.h b/tensorflow/core/grappler/grappler_item_builder.h index 7088636994..ed75749009 100644 --- a/tensorflow/core/grappler/grappler_item_builder.h +++ b/tensorflow/core/grappler/grappler_item_builder.h @@ -27,13 +27,18 @@ class MetaGraphDef; namespace grappler { struct ItemConfig { + ItemConfig() + : ignore_user_placement(true), + ignore_colocation(true), + placeholder_unknown_output_shape_dim(-1) {} + // If true, ignore all user specified node placement. - bool ignore_user_placement = true; + bool ignore_user_placement; // If true, ignore all user specified colocation attributes. - bool ignore_colocation = true; + bool ignore_colocation; // Dimension to use if a placeholder node has an _output_shapes attribute with // a dimension of -1. - int placeholder_unknown_output_shape_dim = -1; + int placeholder_unknown_output_shape_dim; }; // Factory method for creating a GrapplerItem from a MetaGraphDef. -- GitLab From aa1f99845dacba0f37f1b6fad5e51ce7688ee1c3 Mon Sep 17 00:00:00 2001 From: Eugene Brevdo Date: Wed, 26 Apr 2017 12:02:40 -0800 Subject: [PATCH 011/697] [tf contrib seq2seq] Update BeamSearchDecoder + AttentionWrapper API: * BeamSearchDecoder no longer tries to pass a tiling operator to the cell. * Instead the cell must "know" about the fact that batch_size is now bigger. * Much more explicit and careful error messages if batch sizes don't match. * Added helper function to tile the batch by beam_width. * AttentionWrapper now accepts an initial_cell_state argument (it'll pass the value back through when zero_state() is called). Change: 154333600 --- tensorflow/contrib/seq2seq/__init__.py | 2 + .../kernel_tests/attention_wrapper_test.py | 600 +++++++++--------- .../kernel_tests/beam_search_decoder_test.py | 8 +- .../seq2seq/python/ops/attention_wrapper.py | 220 ++++--- .../seq2seq/python/ops/beam_search_decoder.py | 175 +++-- 5 files changed, 568 insertions(+), 437 deletions(-) diff --git a/tensorflow/contrib/seq2seq/__init__.py b/tensorflow/contrib/seq2seq/__init__.py index 064ce00c61..dd497197e3 100644 --- a/tensorflow/contrib/seq2seq/__init__.py +++ b/tensorflow/contrib/seq2seq/__init__.py @@ -44,6 +44,8 @@ See the @{$python/contrib.seq2seq} guide. @@AttentionWrapper @@gather_tree + +@@tile_batch """ from __future__ import absolute_import diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py index 4054b51412..40b50338ad 100644 --- a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py +++ b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py @@ -159,7 +159,7 @@ class AttentionWrapperTest(test.TestCase): sess_results["final_state"]) sys.stdout.flush() print("Copy/paste (%s)\nexpected_final_alignment_history = " % name, - sess_results["state_alignment_history"]) + np.asarray(sess_results["state_alignment_history"])) sys.stdout.flush() nest.map_structure(self.assertAllClose, expected_final_output, sess_results["final_outputs"]) @@ -177,41 +177,41 @@ class AttentionWrapperTest(test.TestCase): expected_final_output = BasicDecoderOutput( rnn_output=array( [[[ - 1.89980457e-03, 1.89681584e-03, 2.05339328e-03, -3.83376027e-03, - -4.31808922e-03, -6.45466987e-03 + 2.04633363e-03, 1.89259532e-03, 2.09550979e-03, -3.81628517e-03, + -4.36160620e-03, -6.43933658e-03 ], [ - 2.27232254e-03, 2.02509761e-03, 2.01666891e-03, -3.87230632e-03, - -3.47119337e-03, -6.15991233e-03 + 2.41885195e-03, 2.02089013e-03, 2.05879519e-03, -3.85483308e-03, + -3.51473060e-03, -6.14458136e-03 ], [ - 1.87640532e-03, 2.07374478e-03, 2.30582547e-03, -3.64564802e-03, - -3.75995948e-03, -6.28685066e-03 + 2.02294230e-03, 2.06955452e-03, 2.34797411e-03, -3.62816593e-03, + -3.80352931e-03, -6.27150526e-03 ]], [[ - 4.89835022e-03, -1.94158917e-03, 3.32316267e-03, - -2.82446202e-03, 3.63192149e-03, -4.80734091e-03 + 4.89025004e-03, -1.97221269e-03, 3.34283570e-03, + -2.79326970e-03, 3.63148772e-03, -4.79645561e-03 ], [ - 5.14256489e-03, -2.00877781e-03, 3.49807227e-03, - -2.86567654e-03, 3.14202951e-03, -5.32575324e-03 + 5.13446378e-03, -2.03941623e-03, 3.51774949e-03, + -2.83448119e-03, 3.14159272e-03, -5.31486655e-03 ], [ - 5.21511910e-03, -2.18198029e-03, 3.56219849e-03, - -2.88951304e-03, 3.20866983e-03, -5.21918852e-03 + 5.20701287e-03, -2.21262546e-03, 3.58187454e-03, + -2.85831164e-03, 3.20822699e-03, -5.20829484e-03 ]], [[ - -1.34951377e-03, -9.68646549e-04, -2.11444520e-03, - -1.85243192e-03, -5.27541339e-03, -9.10969637e-03 + -1.34046993e-03, -9.99792013e-04, -2.11631414e-03, + -1.85202830e-03, -5.26227616e-03, -9.08544939e-03 ], [ - -1.36390887e-03, -1.01293903e-03, -1.96592091e-03, - -1.80044665e-03, -5.62618347e-03, -9.36636236e-03 + -1.35486713e-03, -1.04408595e-03, -1.96779310e-03, + -1.80004584e-03, -5.61304903e-03, -9.34211537e-03 ], [ - -1.13357347e-03, -7.37126335e-04, -1.99582824e-03, - -1.88097963e-03, -5.03196474e-03, -9.34652984e-03 + -1.12452905e-03, -7.68281636e-04, -1.99770415e-03, + -1.88058324e-03, -5.01882844e-03, -9.32228006e-03 ]], [[ - 1.52963377e-03, -3.97205260e-03, -9.64675564e-04, - 8.51404853e-04, -1.29804458e-03, 6.56467676e-03 + 1.52967637e-03, -3.97213362e-03, -9.64699371e-04, + 8.51419638e-04, -1.29806029e-03, 6.56482670e-03 ], [ - 1.22557906e-03, -4.56343032e-03, -1.08188344e-03, - 8.27252632e-04, -2.10058759e-03, 6.43082103e-03 + 1.22562144e-03, -4.56351135e-03, -1.08190742e-03, + 8.27267300e-04, -2.10060296e-03, 6.43097097e-03 ], [ - 9.93478228e-04, -4.37378604e-03, -1.41531695e-03, - 6.44775166e-04, -2.16480484e-03, 6.68286439e-03 + 9.93521884e-04, -4.37386986e-03, -1.41534151e-03, + 6.44790183e-04, -2.16482091e-03, 6.68301852e-03 ]], [[ -3.78854020e-04, 5.62231544e-05, 1.06837302e-04, 1.87137164e-04, -1.56512906e-04, 9.63474595e-05 @@ -231,21 +231,21 @@ class AttentionWrapperTest(test.TestCase): cell_state=LSTMStateTuple( c=array( [[ - -2.18963176e-02, -8.04424379e-03, -1.48289464e-03, - 1.61068402e-02, -1.37983467e-02, -7.57976994e-03, - -8.28560349e-03, -1.18737305e-02, 1.78835373e-02 + -2.18977481e-02, -8.04181397e-03, -1.48273818e-03, + 1.61075518e-02, -1.37986457e-02, -7.57964421e-03, + -8.28644261e-03, -1.18742418e-02, 1.78838037e-02 ], [ - 1.74205080e-02, -1.41929444e-02, -3.88092734e-03, - 3.19708064e-02, -3.54689620e-02, -2.14698724e-02, - -6.21716119e-03, -1.69295724e-03, -1.94495302e-02 + 1.74201727e-02, -1.41931782e-02, -3.88098788e-03, + 3.19711640e-02, -3.54694054e-02, -2.14694049e-02, + -6.21706853e-03, -1.69323490e-03, -1.94494929e-02 ], [ - -1.14528481e-02, 8.77819210e-03, -1.62970200e-02, - -1.39963552e-02, 1.34831406e-02, -1.04494914e-02, - 6.16127765e-03, -9.41022579e-03, -6.57590060e-03 + -1.14532551e-02, 8.77828151e-03, -1.62972715e-02, + -1.39963031e-02, 1.34832524e-02, -1.04488730e-02, + 6.16201758e-03, -9.41041857e-03, -6.57599326e-03 ], [ -4.74753827e-02, -1.19123599e-02, -7.40140676e-05, - 4.10552323e-02, -1.36711076e-03, 2.11795457e-02, - -2.80460119e-02, -5.44509329e-02, -2.91906092e-02 + 4.10552323e-02, -1.36711076e-03, 2.11795494e-02, + -2.80460101e-02, -5.44509329e-02, -2.91906092e-02 ], [ 2.25644894e-02, -1.40382675e-03, 1.92396250e-02, 5.49034867e-03, -1.27930511e-02, -3.15603940e-03, @@ -254,20 +254,20 @@ class AttentionWrapperTest(test.TestCase): dtype=float32), h=array( [[ - -1.09840557e-02, -3.97477299e-03, -7.54582870e-04, - 7.91188516e-03, -7.02184858e-03, -3.80711886e-03, - -4.22059745e-03, -6.05464494e-03, 8.92061181e-03 + -1.09847616e-02, -3.97357112e-03, -7.54502777e-04, + 7.91223347e-03, -7.02199014e-03, -3.80705344e-03, + -4.22102772e-03, -6.05491130e-03, 8.92073940e-03 ], [ - 8.68131686e-03, -7.16938032e-03, -1.88384682e-03, - 1.62678920e-02, -1.76827926e-02, -1.06622791e-02, - -3.07528162e-03, -8.45885137e-04, -9.99388192e-03 + 8.68115202e-03, -7.16950046e-03, -1.88387593e-03, + 1.62680726e-02, -1.76830068e-02, -1.06620435e-02, + -3.07523785e-03, -8.46023730e-04, -9.99386702e-03 ], [ - -5.71205560e-03, 4.50050412e-03, -8.07640795e-03, - -6.94844872e-03, 6.75682165e-03, -5.12113515e-03, - 3.06208082e-03, -4.61743120e-03, -3.23931244e-03 + -5.71225956e-03, 4.50055022e-03, -8.07653368e-03, + -6.94842264e-03, 6.75687613e-03, -5.12083014e-03, + 3.06244940e-03, -4.61752573e-03, -3.23935854e-03 ], [ -2.37231534e-02, -5.88526297e-03, -3.72226204e-05, - 2.01789513e-02, -6.75848918e-04, 1.06686354e-02, + 2.01789513e-02, -6.75848918e-04, 1.06686372e-02, -1.42624676e-02, -2.69628745e-02, -1.45034352e-02 ], [ 1.12585640e-02, -6.92534202e-04, 9.88917705e-03, @@ -277,17 +277,17 @@ class AttentionWrapperTest(test.TestCase): dtype=float32)), attention=array( [[ - 0.00187641, 0.00207374, 0.00230583, -0.00364565, -0.00375996, - -0.00628685 + 0.00202294, 0.00206955, 0.00234797, -0.00362817, -0.00380353, + -0.00627151 ], [ - 0.00521512, -0.00218198, 0.0035622, -0.00288951, 0.00320867, - -0.00521919 + 0.00520701, -0.00221263, 0.00358187, -0.00285831, 0.00320823, + -0.00520829 ], [ - -0.00113357, -0.00073713, -0.00199583, -0.00188098, -0.00503196, - -0.00934653 + -0.00112453, -0.00076828, -0.0019977, -0.00188058, -0.00501883, + -0.00932228 ], [ - 0.00099348, -0.00437379, -0.00141532, 0.00064478, -0.0021648, - 0.00668286 + 0.00099352, -0.00437387, -0.00141534, 0.00064479, -0.00216482, + 0.00668302 ], [ 0.00012445, 0.00022082, 0.00040711, 0.00021803, 0.0002734, -0.00026981 @@ -297,41 +297,41 @@ class AttentionWrapperTest(test.TestCase): alignment_history=()) expected_final_alignment_history = [[[ - 0.12525459, 0.12438694, 0.12613983, 0.12484372, 0.12484372, 0.12484372, - 0.12484372, 0.12484372 + 0.12586178, 0.12272788, 0.1271652, 0.12484902, 0.12484902, 0.12484902, + 0.12484902, 0.12484902 ], [ - 0.12648369, 0.12322279, 0.12504892, 0.12504892, 0.12504892, 0.12504892, - 0.12504892, 0.12504892 + 0.12612638, 0.12516938, 0.12478404, 0.12478404, 0.12478404, 0.12478404, + 0.12478404, 0.12478404 ], [ - 0.12611018, 0.12528601, 0.12638952, 0.12444285, 0.12444285, 0.12444285, - 0.12444285, 0.12444285 + 0.12595113, 0.12515794, 0.1255464, 0.1246689, 0.1246689, 0.1246689, + 0.1246689, 0.1246689 ], [ - 0.12492625, 0.12501054, 0.12501054, 0.12501054, 0.12501054, 0.12501054, - 0.12501054, 0.12501054 + 0.12492912, 0.12501013, 0.12501013, 0.12501013, 0.12501013, 0.12501013, + 0.12501013, 0.12501013 ], [0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125]], [[ - 0.12525459, 0.12438691, 0.1261397, 0.12484375, 0.12484375, 0.12484375, - 0.12484375, 0.12484375 + 0.12586173, 0.12272781, 0.12716517, 0.12484905, 0.12484905, 0.12484905, + 0.12484905, 0.12484905 ], [ - 0.12648349, 0.12322238, 0.12504902, 0.12504902, 0.12504902, 0.12504902, - 0.12504902, 0.12504902 + 0.12612617, 0.1251694, 0.12478408, 0.12478408, 0.12478408, 0.12478408, + 0.12478408, 0.12478408 ], [ - 0.12611009, 0.12528586, 0.12638941, 0.12444293, 0.12444293, 0.12444293, - 0.12444293, 0.12444293 + 0.12595108, 0.12515777, 0.1255464, 0.12466895, 0.12466895, 0.12466895, + 0.12466895, 0.12466895 ], [ - 0.12492625, 0.12501054, 0.12501054, 0.12501054, 0.12501054, 0.12501054, - 0.12501054, 0.12501054 + 0.12492914, 0.12501012, 0.12501012, 0.12501012, 0.12501012, 0.12501012, + 0.12501012, 0.12501012 ], [0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125]], [[ - 0.12525487, 0.12438726, 0.12613991, 0.12484358, 0.12484358, 0.12484358, - 0.12484358, 0.12484358 + 0.12586181, 0.12272815, 0.12716556, 0.12484891, 0.12484891, 0.12484891, + 0.12484891, 0.12484891 ], [ - 0.12648354, 0.12322233, 0.12504901, 0.12504901, 0.12504901, 0.12504901, - 0.12504901, 0.12504901 + 0.12612608, 0.12516941, 0.12478409, 0.12478409, 0.12478409, 0.12478409, + 0.12478409, 0.12478409 ], [ - 0.12611021, 0.125286, 0.12638955, 0.12444286, 0.12444286, 0.12444286, - 0.12444286, 0.12444286 + 0.12595116, 0.12515792, 0.12554643, 0.1246689, 0.1246689, 0.1246689, + 0.1246689, 0.1246689 ], [ - 0.12492625, 0.12501054, 0.12501054, 0.12501054, 0.12501054, 0.12501054, - 0.12501054, 0.12501054 + 0.1249292, 0.12501012, 0.12501012, 0.12501012, 0.12501012, 0.12501012, + 0.12501012, 0.12501012 ], [0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125]]] self._testWithAttention( @@ -349,41 +349,41 @@ class AttentionWrapperTest(test.TestCase): expected_final_output = BasicDecoderOutput( rnn_output=array( [[[ - 6.64783875e-03, 2.94425711e-03, 5.26542449e-03, -2.64955591e-03, - -7.95925129e-03, -5.02286293e-03 + 1.27064800e-02, 3.57783446e-03, 8.22613202e-03, -1.61504047e-03, + -1.12555185e-02, -3.92740499e-03 ], [ - 7.01954123e-03, 3.07301106e-03, 5.22849336e-03, -2.68844375e-03, - -7.11239874e-03, -4.72904276e-03 + 1.30781950e-02, 3.70747922e-03, 8.18992872e-03, -1.65389013e-03, + -1.04098395e-02, -3.63383139e-03 ], [ - 6.62360899e-03, 3.12234787e-03, 5.51807694e-03, -2.46222341e-03, - -7.40198931e-03, -4.85701021e-03 + 1.26833543e-02, 3.75790196e-03, 8.48123431e-03, -1.42690970e-03, + -1.07016256e-02, -3.76088684e-03 ]], [[ - 7.37589924e-03, -1.02620223e-03, 3.61374952e-03, - -5.74620720e-03, 5.05625410e-03, -7.45209027e-03 + 6.88417302e-03, -2.04071682e-03, 4.17768257e-03, + -4.51408979e-03, 4.90086433e-03, -6.85973791e-03 ], [ - 7.61946291e-03, -1.09287468e-03, 3.78817180e-03, - -5.78709645e-03, 4.56611114e-03, -7.96987582e-03 + 7.12782983e-03, -2.10783770e-03, 4.35227761e-03, + -4.55496181e-03, 4.41066315e-03, -7.37757795e-03 ], [ - 7.69207766e-03, -1.26582675e-03, 3.85218812e-03, - -5.81111759e-03, 4.63287206e-03, -7.86337163e-03 + 7.20011396e-03, -2.28102156e-03, 4.41620918e-03, + -4.57867794e-03, 4.47713351e-03, -7.27072079e-03 ]], [[ - -2.69413739e-03, 3.47183552e-04, -1.82145904e-03, - -1.39805069e-03, -8.05486552e-03, -1.08372131e-02 + -2.20676698e-03, -1.43745833e-03, -1.99429039e-03, + -1.44722988e-03, -7.45461835e-03, -9.80243273e-03 ], [ - -2.70848931e-03, 3.03293345e-04, -1.67230750e-03, - -1.34555507e-03, -8.40565283e-03, -1.10935047e-02 + -2.22120387e-03, -1.48139545e-03, -1.84528576e-03, + -1.39490096e-03, -7.80559657e-03, -1.00586927e-02 ], [ - -2.47822329e-03, 5.79408603e-04, -1.70188327e-03, - -1.42583530e-03, -7.81180616e-03, -1.10740755e-02 + -1.99079141e-03, -1.20571791e-03, -1.87507609e-03, + -1.47541985e-03, -7.21158786e-03, -1.00391749e-02 ]], [[ - 1.48582947e-03, -3.88786104e-03, -9.39912978e-04, - 8.36255029e-04, -1.28223014e-03, 6.40908210e-03 + 1.48755650e-03, -3.89118027e-03, -9.40889120e-04, + 8.36852356e-04, -1.28285377e-03, 6.41521579e-03 ], [ - 1.18177081e-03, -4.47923271e-03, -1.05711201e-03, - 8.12121783e-04, -2.08477327e-03, 6.27523474e-03 + 1.18351437e-03, -4.48258361e-03, -1.05809816e-03, + 8.12723883e-04, -2.08540238e-03, 6.28142804e-03 ], [ - 9.49664740e-04, -4.28957958e-03, -1.39053771e-03, - 6.29657647e-04, -2.14899099e-03, 6.52727811e-03 + 9.51444614e-04, -4.29300033e-03, -1.39154412e-03, + 6.30271854e-04, -2.14963360e-03, 6.53359853e-03 ]], [[ -3.78854020e-04, 5.62231544e-05, 1.06837302e-04, 1.87137164e-04, -1.56512906e-04, 9.63474595e-05 @@ -396,28 +396,28 @@ class AttentionWrapperTest(test.TestCase): ]]], dtype=float32), sample_id=array( - [[0, 0, 0], [0, 0, 0], [1, 1, 1], [5, 5, 5], [3, 3, 2]], + [[0, 0, 0], [0, 0, 0], [1, 3, 1], [5, 5, 5], [3, 3, 2]], dtype=int32)) expected_final_state = AttentionWrapperState( cell_state=LSTMStateTuple( c=array( [[ - -2.19389871e-02, -7.93421268e-03, -1.45148858e-03, - 1.61569901e-02, -1.38310911e-02, -7.59426132e-03, - -8.35836027e-03, -1.18763093e-02, 1.78797375e-02 + -2.19953191e-02, -7.81358499e-03, -1.42740645e-03, + 1.62037201e-02, -1.38600282e-02, -7.60386931e-03, + -8.42390209e-03, -1.18884994e-02, 1.78821683e-02 ], [ - 1.74194798e-02, -1.41677596e-02, -3.89095861e-03, - 3.19508761e-02, -3.54519747e-02, -2.15105712e-02, - -6.20894879e-03, -1.72719418e-03, -1.94605980e-02 + 1.74096227e-02, -1.41773149e-02, -3.89175024e-03, + 3.19635086e-02, -3.54669318e-02, -2.14924756e-02, + -6.20695669e-03, -1.73213519e-03, -1.94583312e-02 ], [ - -1.14357909e-02, 8.76635592e-03, -1.62690803e-02, - -1.39883338e-02, 1.34323873e-02, -1.04959216e-02, - 6.09614328e-03, -9.38197412e-03, -6.57159975e-03 + -1.14590004e-02, 8.76899902e-03, -1.62825100e-02, + -1.39863417e-02, 1.34333782e-02, -1.04652103e-02, + 6.13503950e-03, -9.39247012e-03, -6.57595927e-03 ], [ - -4.74738739e-02, -1.19136795e-02, -7.36564398e-05, - 4.10547666e-02, -1.36771239e-03, 2.11771261e-02, - -2.80481018e-02, -5.44515178e-02, -2.91903559e-02 + -4.74739373e-02, -1.19136302e-02, -7.36713409e-05, + 4.10547927e-02, -1.36768632e-03, 2.11772211e-02, + -2.80480143e-02, -5.44514954e-02, -2.91903671e-02 ], [ 2.25644894e-02, -1.40382675e-03, 1.92396250e-02, 5.49034867e-03, -1.27930511e-02, -3.15603940e-03, @@ -426,21 +426,21 @@ class AttentionWrapperTest(test.TestCase): dtype=float32), h=array( [[ - -1.10049099e-02, -3.92028037e-03, -7.38571223e-04, - 7.93652050e-03, -7.03821564e-03, -3.81436548e-03, - -4.25778655e-03, -6.05606195e-03, 8.91851448e-03 + -1.10325804e-02, -3.86056723e-03, -7.26287195e-04, + 7.95945339e-03, -7.05253659e-03, -3.81913339e-03, + -4.29130904e-03, -6.06246945e-03, 8.91948957e-03 ], [ - 8.68070032e-03, -7.15647917e-03, -1.88874488e-03, - 1.62575077e-02, -1.76745858e-02, -1.06826536e-02, - -3.07105901e-03, -8.63034453e-04, -9.99918394e-03 + 8.67583323e-03, -7.16136536e-03, -1.88911252e-03, + 1.62639488e-02, -1.76817775e-02, -1.06735229e-02, + -3.07015004e-03, -8.65494134e-04, -9.99815390e-03 ], [ - -5.70359221e-03, 4.49446775e-03, -8.06238409e-03, - -6.94446685e-03, 6.73149945e-03, -5.14409645e-03, - 3.02969781e-03, -4.60351165e-03, -3.23720207e-03 + -5.71519835e-03, 4.49585915e-03, -8.06909613e-03, + -6.94347266e-03, 6.73189852e-03, -5.12895826e-03, + 3.04909074e-03, -4.60868096e-03, -3.23936995e-03 ], [ - -2.37224046e-02, -5.88591257e-03, -3.70427515e-05, - 2.01787166e-02, -6.76146999e-04, 1.06674293e-02, - -1.42635051e-02, -2.69631781e-02, -1.45033030e-02 + -2.37224363e-02, -5.88588836e-03, -3.70502457e-05, + 2.01787297e-02, -6.76134136e-04, 1.06674768e-02, + -1.42634623e-02, -2.69631669e-02, -1.45033086e-02 ], [ 1.12585640e-02, -6.92534202e-04, 9.88917705e-03, 2.75237625e-03, -6.56115822e-03, -1.57997780e-03, @@ -449,17 +449,17 @@ class AttentionWrapperTest(test.TestCase): dtype=float32)), attention=array( [[ - 0.00662361, 0.00312235, 0.00551808, -0.00246222, -0.00740199, - -0.00485701 + 0.01268335, 0.0037579, 0.00848123, -0.00142691, -0.01070163, + -0.00376089 ], [ - 0.00769208, -0.00126583, 0.00385219, -0.00581112, 0.00463287, - -0.00786337 + 0.00720011, -0.00228102, 0.00441621, -0.00457868, 0.00447713, + -0.00727072 ], [ - -0.00247822, 0.00057941, -0.00170188, -0.00142584, -0.00781181, - -0.01107408 + -0.00199079, -0.00120572, -0.00187508, -0.00147542, -0.00721159, + -0.01003917 ], [ - 0.00094966, -0.00428958, -0.00139054, 0.00062966, -0.00214899, - 0.00652728 + 0.00095144, -0.004293, -0.00139154, 0.00063027, -0.00214963, + 0.0065336 ], [ 0.00012445, 0.00022082, 0.00040711, 0.00021803, 0.0002734, -0.00026981 @@ -480,41 +480,41 @@ class AttentionWrapperTest(test.TestCase): expected_final_output = BasicDecoderOutput( rnn_output=array( [[[ - 1.74749165e-03, 1.95862399e-03, 2.12293095e-03, -3.75889172e-03, - -4.39571124e-03, -6.32379763e-03 + 1.74922391e-03, 1.85935036e-03, 1.90880906e-03, -3.96941090e-03, + -4.17229906e-03, -6.65769773e-03 ], [ - 2.33045570e-03, 1.99094601e-03, 1.98377599e-03, -3.87950847e-03, - -3.42792575e-03, -6.17497414e-03 + 1.99638237e-03, 1.91135216e-03, 1.73234346e-03, -4.00905171e-03, + -3.15058464e-03, -6.34974428e-03 ], [ - 1.65032526e-03, 1.96972815e-03, 2.03462853e-03, -3.82007333e-03, - -3.46369296e-03, -6.54224353e-03 + 2.08854163e-03, 2.13832827e-03, 2.49780947e-03, -3.52849509e-03, + -3.96897132e-03, -6.12034509e-03 ]], [[ - 4.77780215e-03, -1.98677275e-03, 3.30950436e-03, - -2.68179504e-03, 3.56271653e-03, -4.67860466e-03 + 4.76492243e-03, -1.97180966e-03, 3.29327444e-03, + -2.68205139e-03, 3.55229783e-03, -4.66645230e-03 ], [ - 5.13039157e-03, -2.02797214e-03, 3.50760575e-03, - -2.83981953e-03, 3.13726603e-03, -5.31156827e-03 + 5.24956919e-03, -2.00631656e-03, 3.53828911e-03, + -2.96283513e-03, 3.20920302e-03, -5.43697737e-03 ], [ - 5.17205056e-03, -2.16446724e-03, 3.53219034e-03, - -2.86490913e-03, 3.17879021e-03, -5.17592067e-03 + 5.30424621e-03, -2.17913301e-03, 3.59509978e-03, + -2.97106663e-03, 3.26450402e-03, -5.31189423e-03 ]], [[ - -1.38538703e-03, -6.40910701e-04, -2.02864106e-03, - -1.79018872e-03, -5.18789608e-03, -8.95875692e-03 + -1.36440888e-03, -9.75572329e-04, -2.11284542e-03, + -1.84616144e-03, -5.31351101e-03, -9.12462734e-03 ], [ - -1.38620089e-03, -7.92010222e-04, -1.91070826e-03, - -1.76206254e-03, -5.56525169e-03, -9.27332044e-03 + -1.41863467e-03, -1.11081311e-03, -1.94056751e-03, + -1.74311269e-03, -5.76282106e-03, -9.29267984e-03 ], [ - -1.11966045e-03, -6.07630936e-04, -1.96643686e-03, - -1.86803937e-03, -4.93048411e-03, -9.25842486e-03 + -1.12129003e-03, -8.15156149e-04, -2.01535341e-03, + -1.89556007e-03, -5.04226238e-03, -9.37188603e-03 ]], [[ - 1.50820788e-03, -3.93087184e-03, -9.52563598e-04, - 8.43994785e-04, -1.29030924e-03, 6.48857141e-03 + 1.55163277e-03, -4.01433324e-03, -9.77111282e-04, + 8.59013060e-04, -1.30598655e-03, 6.64281659e-03 ], [ - 1.17029145e-03, -4.45716921e-03, -1.05062663e-03, - 8.08141369e-04, -2.08062865e-03, 6.23444980e-03 + 1.26811734e-03, -4.64518648e-03, -1.10593368e-03, + 8.41954607e-04, -2.11594440e-03, 6.58190623e-03 ], [ - 9.67921398e-04, -4.32466762e-03, -1.40085898e-03, - 6.35969569e-04, -2.15558149e-03, 6.59212377e-03 + 1.02682540e-03, -4.43787826e-03, -1.43417739e-03, + 6.56281307e-04, -2.17684195e-03, 6.80128345e-03 ]], [[ -3.78854020e-04, 5.62231544e-05, 1.06837302e-04, 1.87137164e-04, -1.56512906e-04, 9.63474595e-05 @@ -534,21 +534,21 @@ class AttentionWrapperTest(test.TestCase): cell_state=LSTMStateTuple( c=array( [[ - -2.18960866e-02, -8.04429129e-03, -1.48267671e-03, - 1.61071159e-02, -1.37981661e-02, -7.57933082e-03, - -8.28570686e-03, -1.18733812e-02, 1.78834442e-02 + -2.18942575e-02, -8.05099495e-03, -1.48526859e-03, + 1.61030665e-02, -1.37967104e-02, -7.57982396e-03, + -8.28088820e-03, -1.18743815e-02, 1.78839806e-02 ], [ - 1.74204130e-02, -1.41935758e-02, -3.88074201e-03, - 3.19713727e-02, -3.54694910e-02, -2.14688145e-02, - -6.21731905e-03, -1.69229065e-03, -1.94492843e-02 + 1.74203254e-02, -1.41929490e-02, -3.88103351e-03, + 3.19709182e-02, -3.54691371e-02, -2.14697979e-02, + -6.21709181e-03, -1.69324467e-03, -1.94495786e-02 ], [ - -1.14494488e-02, 8.77974741e-03, -1.62960067e-02, - -1.39961652e-02, 1.34879015e-02, -1.04502086e-02, - 6.15879148e-03, -9.40956455e-03, -6.57592434e-03 + -1.14536462e-02, 8.77809525e-03, -1.62965059e-02, + -1.39955431e-02, 1.34810507e-02, -1.04491040e-02, + 6.16097450e-03, -9.40943789e-03, -6.57613343e-03 ], [ - -4.74739634e-02, -1.19136050e-02, -7.36759976e-05, - 4.10547927e-02, -1.36767328e-03, 2.11772677e-02, - -2.80479677e-02, -5.44514805e-02, -2.91903690e-02 + -4.74765450e-02, -1.19113335e-02, -7.42897391e-05, + 4.10555862e-02, -1.36665069e-03, 2.11814232e-02, + -2.80444007e-02, -5.44504896e-02, -2.91908123e-02 ], [ 2.25644894e-02, -1.40382675e-03, 1.92396250e-02, 5.49034867e-03, -1.27930511e-02, -3.15603940e-03, @@ -557,21 +557,21 @@ class AttentionWrapperTest(test.TestCase): dtype=float32), h=array( [[ - -1.09839402e-02, -3.97479767e-03, -7.54472159e-04, - 7.91201927e-03, -7.02175125e-03, -3.80689627e-03, - -4.22065007e-03, -6.05447078e-03, 8.92056432e-03 + -1.09830676e-02, -3.97811923e-03, -7.55793473e-04, + 7.91002903e-03, -7.02103321e-03, -3.80714820e-03, + -4.21818346e-03, -6.05497835e-03, 8.92084371e-03 ], [ - 8.68127123e-03, -7.16970162e-03, -1.88375649e-03, - 1.62681788e-02, -1.76830534e-02, -1.06617520e-02, - -3.07536125e-03, -8.45551898e-04, -9.99375992e-03 + 8.68122280e-03, -7.16937613e-03, -1.88389909e-03, + 1.62679367e-02, -1.76828820e-02, -1.06622437e-02, + -3.07524228e-03, -8.46030540e-04, -9.99389403e-03 ], [ - -5.71034756e-03, 4.50129062e-03, -8.07590690e-03, - -6.94835978e-03, 6.75921654e-03, -5.12148207e-03, - 3.06083867e-03, -4.61710012e-03, -3.23932176e-03 + -5.71245840e-03, 4.50045895e-03, -8.07614625e-03, + -6.94804778e-03, 6.75577158e-03, -5.12094703e-03, + 3.06193763e-03, -4.61703911e-03, -3.23943049e-03 ], [ - -2.37224493e-02, -5.88587578e-03, -3.70525813e-05, - 2.01787278e-02, -6.76127791e-04, 1.06675029e-02, - -1.42634306e-02, -2.69631632e-02, -1.45033058e-02 + -2.37237271e-02, -5.88475820e-03, -3.73612711e-05, + 2.01791357e-02, -6.75620860e-04, 1.06695695e-02, + -1.42616741e-02, -2.69626491e-02, -1.45035451e-02 ], [ 1.12585640e-02, -6.92534202e-04, 9.88917705e-03, 2.75237625e-03, -6.56115822e-03, -1.57997780e-03, @@ -580,17 +580,17 @@ class AttentionWrapperTest(test.TestCase): dtype=float32)), attention=array( [[ - 0.00165033, 0.00196973, 0.00203463, -0.00382007, -0.00346369, - -0.00654224 + 0.00208854, 0.00213833, 0.00249781, -0.0035285, -0.00396897, + -0.00612035 ], [ - 0.00517205, -0.00216447, 0.00353219, -0.00286491, 0.00317879, - -0.00517592 + 0.00530425, -0.00217913, 0.0035951, -0.00297107, 0.0032645, + -0.00531189 ], [ - -0.00111966, -0.00060763, -0.00196644, -0.00186804, -0.00493048, - -0.00925842 + -0.00112129, -0.00081516, -0.00201535, -0.00189556, -0.00504226, + -0.00937189 ], [ - 0.00096792, -0.00432467, -0.00140086, 0.00063597, -0.00215558, - 0.00659212 + 0.00102683, -0.00443788, -0.00143418, 0.00065628, -0.00217684, + 0.00680128 ], [ 0.00012445, 0.00022082, 0.00040711, 0.00021803, 0.0002734, -0.00026981 @@ -613,41 +613,41 @@ class AttentionWrapperTest(test.TestCase): expected_final_output = BasicDecoderOutput( rnn_output=array( [[[ - 1.74749165e-03, 1.95862399e-03, 2.12293095e-03, -3.75889172e-03, - -4.39571124e-03, -6.32379763e-03 + 1.74922391e-03, 1.85935036e-03, 1.90880906e-03, -3.96941090e-03, + -4.17229906e-03, -6.65769773e-03 ], [ - 2.33045570e-03, 1.99094601e-03, 1.98377599e-03, -3.87950847e-03, - -3.42792575e-03, -6.17497414e-03 + 1.99638237e-03, 1.91135216e-03, 1.73234346e-03, -4.00905171e-03, + -3.15058464e-03, -6.34974428e-03 ], [ - 1.65032526e-03, 1.96972815e-03, 2.03462853e-03, -3.82007333e-03, - -3.46369296e-03, -6.54224353e-03 + 2.08854163e-03, 2.13832827e-03, 2.49780947e-03, -3.52849509e-03, + -3.96897132e-03, -6.12034509e-03 ]], [[ - 4.77780215e-03, -1.98677275e-03, 3.30950436e-03, - -2.68179504e-03, 3.56271653e-03, -4.67860466e-03 + 4.76492243e-03, -1.97180966e-03, 3.29327444e-03, + -2.68205139e-03, 3.55229783e-03, -4.66645230e-03 ], [ - 5.13039157e-03, -2.02797214e-03, 3.50760575e-03, - -2.83981953e-03, 3.13726603e-03, -5.31156827e-03 + 5.24956919e-03, -2.00631656e-03, 3.53828911e-03, + -2.96283513e-03, 3.20920302e-03, -5.43697737e-03 ], [ - 5.17205056e-03, -2.16446724e-03, 3.53219034e-03, - -2.86490913e-03, 3.17879021e-03, -5.17592067e-03 + 5.30424621e-03, -2.17913301e-03, 3.59509978e-03, + -2.97106663e-03, 3.26450402e-03, -5.31189423e-03 ]], [[ - -1.38538703e-03, -6.40910701e-04, -2.02864106e-03, - -1.79018872e-03, -5.18789608e-03, -8.95875692e-03 + -1.36440888e-03, -9.75572329e-04, -2.11284542e-03, + -1.84616144e-03, -5.31351101e-03, -9.12462734e-03 ], [ - -1.38620089e-03, -7.92010222e-04, -1.91070826e-03, - -1.76206254e-03, -5.56525169e-03, -9.27332044e-03 + -1.41863467e-03, -1.11081311e-03, -1.94056751e-03, + -1.74311269e-03, -5.76282106e-03, -9.29267984e-03 ], [ - -1.11966045e-03, -6.07630936e-04, -1.96643686e-03, - -1.86803937e-03, -4.93048411e-03, -9.25842486e-03 + -1.12129003e-03, -8.15156149e-04, -2.01535341e-03, + -1.89556007e-03, -5.04226238e-03, -9.37188603e-03 ]], [[ - 1.50820788e-03, -3.93087184e-03, -9.52563598e-04, - 8.43994785e-04, -1.29030924e-03, 6.48857141e-03 + 1.55163277e-03, -4.01433324e-03, -9.77111282e-04, + 8.59013060e-04, -1.30598655e-03, 6.64281659e-03 ], [ - 1.17029145e-03, -4.45716921e-03, -1.05062663e-03, - 8.08141369e-04, -2.08062865e-03, 6.23444980e-03 + 1.26811734e-03, -4.64518648e-03, -1.10593368e-03, + 8.41954607e-04, -2.11594440e-03, 6.58190623e-03 ], [ - 9.67921398e-04, -4.32466762e-03, -1.40085898e-03, - 6.35969569e-04, -2.15558149e-03, 6.59212377e-03 + 1.02682540e-03, -4.43787826e-03, -1.43417739e-03, + 6.56281307e-04, -2.17684195e-03, 6.80128345e-03 ]], [[ -3.78854020e-04, 5.62231544e-05, 1.06837302e-04, 1.87137164e-04, -1.56512906e-04, 9.63474595e-05 @@ -667,21 +667,21 @@ class AttentionWrapperTest(test.TestCase): cell_state=LSTMStateTuple( c=array( [[ - -2.18960866e-02, -8.04429129e-03, -1.48267671e-03, - 1.61071159e-02, -1.37981661e-02, -7.57933082e-03, - -8.28570686e-03, -1.18733812e-02, 1.78834442e-02 + -2.18942575e-02, -8.05099495e-03, -1.48526859e-03, + 1.61030665e-02, -1.37967104e-02, -7.57982396e-03, + -8.28088820e-03, -1.18743815e-02, 1.78839806e-02 ], [ - 1.74204130e-02, -1.41935758e-02, -3.88074201e-03, - 3.19713727e-02, -3.54694910e-02, -2.14688145e-02, - -6.21731905e-03, -1.69229065e-03, -1.94492843e-02 + 1.74203254e-02, -1.41929490e-02, -3.88103351e-03, + 3.19709182e-02, -3.54691371e-02, -2.14697979e-02, + -6.21709181e-03, -1.69324467e-03, -1.94495786e-02 ], [ - -1.14494488e-02, 8.77974741e-03, -1.62960067e-02, - -1.39961652e-02, 1.34879015e-02, -1.04502086e-02, - 6.15879148e-03, -9.40956455e-03, -6.57592434e-03 + -1.14536462e-02, 8.77809525e-03, -1.62965059e-02, + -1.39955431e-02, 1.34810507e-02, -1.04491040e-02, + 6.16097450e-03, -9.40943789e-03, -6.57613343e-03 ], [ - -4.74739634e-02, -1.19136050e-02, -7.36759976e-05, - 4.10547927e-02, -1.36767328e-03, 2.11772677e-02, - -2.80479677e-02, -5.44514805e-02, -2.91903690e-02 + -4.74765450e-02, -1.19113335e-02, -7.42897391e-05, + 4.10555862e-02, -1.36665069e-03, 2.11814232e-02, + -2.80444007e-02, -5.44504896e-02, -2.91908123e-02 ], [ 2.25644894e-02, -1.40382675e-03, 1.92396250e-02, 5.49034867e-03, -1.27930511e-02, -3.15603940e-03, @@ -690,21 +690,21 @@ class AttentionWrapperTest(test.TestCase): dtype=float32), h=array( [[ - -1.09839402e-02, -3.97479767e-03, -7.54472159e-04, - 7.91201927e-03, -7.02175125e-03, -3.80689627e-03, - -4.22065007e-03, -6.05447078e-03, 8.92056432e-03 + -1.09830676e-02, -3.97811923e-03, -7.55793473e-04, + 7.91002903e-03, -7.02103321e-03, -3.80714820e-03, + -4.21818346e-03, -6.05497835e-03, 8.92084371e-03 ], [ - 8.68127123e-03, -7.16970162e-03, -1.88375649e-03, - 1.62681788e-02, -1.76830534e-02, -1.06617520e-02, - -3.07536125e-03, -8.45551898e-04, -9.99375992e-03 + 8.68122280e-03, -7.16937613e-03, -1.88389909e-03, + 1.62679367e-02, -1.76828820e-02, -1.06622437e-02, + -3.07524228e-03, -8.46030540e-04, -9.99389403e-03 ], [ - -5.71034756e-03, 4.50129062e-03, -8.07590690e-03, - -6.94835978e-03, 6.75921654e-03, -5.12148207e-03, - 3.06083867e-03, -4.61710012e-03, -3.23932176e-03 + -5.71245840e-03, 4.50045895e-03, -8.07614625e-03, + -6.94804778e-03, 6.75577158e-03, -5.12094703e-03, + 3.06193763e-03, -4.61703911e-03, -3.23943049e-03 ], [ - -2.37224493e-02, -5.88587578e-03, -3.70525813e-05, - 2.01787278e-02, -6.76127791e-04, 1.06675029e-02, - -1.42634306e-02, -2.69631632e-02, -1.45033058e-02 + -2.37237271e-02, -5.88475820e-03, -3.73612711e-05, + 2.01791357e-02, -6.75620860e-04, 1.06695695e-02, + -1.42616741e-02, -2.69626491e-02, -1.45035451e-02 ], [ 1.12585640e-02, -6.92534202e-04, 9.88917705e-03, 2.75237625e-03, -6.56115822e-03, -1.57997780e-03, @@ -713,17 +713,17 @@ class AttentionWrapperTest(test.TestCase): dtype=float32)), attention=array( [[ - 0.00165033, 0.00196973, 0.00203463, -0.00382007, -0.00346369, - -0.00654224 + 0.00208854, 0.00213833, 0.00249781, -0.0035285, -0.00396897, + -0.00612035 ], [ - 0.00517205, -0.00216447, 0.00353219, -0.00286491, 0.00317879, - -0.00517592 + 0.00530425, -0.00217913, 0.0035951, -0.00297107, 0.0032645, + -0.00531189 ], [ - -0.00111966, -0.00060763, -0.00196644, -0.00186804, -0.00493048, - -0.00925842 + -0.00112129, -0.00081516, -0.00201535, -0.00189556, -0.00504226, + -0.00937189 ], [ - 0.00096792, -0.00432467, -0.00140086, 0.00063597, -0.00215558, - 0.00659212 + 0.00102683, -0.00443788, -0.00143418, 0.00065628, -0.00217684, + 0.00680128 ], [ 0.00012445, 0.00022082, 0.00040711, 0.00021803, 0.0002734, -0.00026981 @@ -745,41 +745,41 @@ class AttentionWrapperTest(test.TestCase): expected_final_output = BasicDecoderOutput( rnn_output=array( [[[ - -0.24277855, -0.07872247, 0.15671003, 0.24866608, 0.30515286, - -0.24618624, 0.0507621, 0.18785201, -0.16496147, -0.138405 + -0.24223405, -0.07791166, 0.15451428, 0.24738294, 0.30900395, + -0.24685201, 0.04992372, 0.18749543, -0.15878429, -0.13678923 ], [ - -0.24277849, -0.07872227, 0.15671015, 0.24866594, 0.30515245, - -0.24618617, 0.05076224, 0.18785192, -0.16496134, -0.13840486 + -0.2422339, -0.07791159, 0.15451418, 0.24738279, 0.30900383, + -0.24685188, 0.04992369, 0.18749531, -0.15878411, -0.13678911 ], [ - -0.24277903, -0.07872239, 0.15671065, 0.24866652, 0.30515283, - -0.24618667, 0.05076243, 0.18785232, -0.16496192, -0.13840519 + -0.2422343, -0.07791215, 0.15451413, 0.24738336, 0.30900475, + -0.2468522, 0.04992349, 0.18749571, -0.158785, -0.13678965 ]], [[ - 0.39683789, 0.12260036, -0.06023827, -0.09247135, 0.11370862, - -0.1547074, 0.00654875, -0.26491123, 0.08399884, 0.1876322 + 0.40035266, 0.12299616, -0.06085059, -0.09197108, 0.11368551, + -0.15302914, 0.00566157, -0.26885766, 0.08546552, 0.18886778 ], [ - 0.39683688, 0.12260011, -0.06023812, -0.09247123, 0.11370841, - -0.15470725, 0.00654882, -0.26491043, 0.08399857, 0.18763176 + 0.40035242, 0.12299603, -0.06085056, -0.09197091, 0.11368536, + -0.15302882, 0.0056615, -0.26885763, 0.08546554, 0.18886763 ], [ - 0.39683694, 0.12260016, -0.06023812, -0.09247129, 0.11370847, - -0.1547074, 0.00654885, -0.2649104, 0.08399855, 0.18763182 + 0.40035242, 0.122996, -0.06085056, -0.09197087, 0.11368532, + -0.1530287, 0.00566146, -0.26885769, 0.08546556, 0.18886761 ]], [[ - -0.432805, 0.07398784, -0.01561836, 0.19199517, -0.02651545, - -0.21643993, -0.02017856, 0.00162333, 0.21297953, 0.25590748 + -0.4311333, 0.07519469, -0.01551808, 0.1913045, -0.02693807, + -0.21668895, -0.02155721, 0.0013397, 0.21180844, 0.25578707 ], [ - -0.43280441, 0.07398778, -0.01561838, 0.19199494, -0.02651539, - -0.2164396, -0.02017859, 0.00162333, 0.2129792, 0.25590718 + -0.43113309, 0.07519454, -0.01551818, 0.19130446, -0.0269379, + -0.21668854, -0.021557, 0.00133975, 0.21180828, 0.25578681 ], [ - -0.43280473, 0.07398786, -0.01561838, 0.19199508, -0.02651543, - -0.21643978, -0.02017862, 0.00162333, 0.21297936, 0.25590736 + -0.43113324, 0.07519463, -0.01551815, 0.1913045, -0.02693798, + -0.21668874, -0.02155712, 0.00133973, 0.21180835, 0.25578696 ]], [[ - 0.07059769, 0.16451193, 0.01174642, 0.04646424, 0.14275651, - 0.07944378, -0.10852743, 0.15305836, 0.02151343, -0.05589932 + 0.07059932, 0.16451572, 0.01174669, 0.04646531, 0.1427598, + 0.0794456, -0.10852993, 0.15306188, 0.02151393, -0.05590061 ], [ - 0.07059769, 0.16451195, 0.01174642, 0.04646424, 0.14275652, - 0.07944378, -0.10852744, 0.15305838, 0.02151344, -0.05589932 + 0.07059933, 0.16451576, 0.01174669, 0.04646532, 0.14275983, + 0.07944562, -0.10852996, 0.15306193, 0.02151394, -0.05590062 ], [ - 0.07059769, 0.16451193, 0.01174642, 0.04646424, 0.14275651, - 0.07944378, -0.10852743, 0.15305836, 0.02151343, -0.05589932 + 0.07059937, 0.16451585, 0.0117467, 0.04646534, 0.1427599, + 0.07944567, -0.10853001, 0.153062, 0.02151395, -0.05590065 ]], [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]], @@ -792,20 +792,19 @@ class AttentionWrapperTest(test.TestCase): cell_state=LSTMStateTuple( c=array( [[ - -0.01816516, -0.01675301, -0.00513183, 0.01560374, - -0.01254231, -0.00434179, -0.01262734, -0.01721472, - 0.02279618 + -0.0181195, -0.01675365, -0.00510353, 0.01559796, + -0.01251448, -0.00437002, -0.01243257, -0.01720199, + 0.02274928 ], [ - 0.01265936, -0.00846189, -0.00376618, 0.0313628, - -0.03488337, -0.02464793, -0.00498007, -0.00459409, - -0.02097382 + 0.01259979, -0.00839985, -0.00374037, 0.03136262, + -0.03486227, -0.02466441, -0.00496157, -0.00461032, + -0.02098336 ], [ - -0.00780995, 0.00314269, -0.01383716, -0.01147376, - 0.00483319, -0.01340735, 0.00865264, -0.00632812, - -0.01053766 + -0.00781067, 0.00315682, -0.0138283, -0.01149793, + 0.00485562, -0.01343193, 0.0085915, -0.00632846, -0.01052086 ], [ - -0.04184838, -0.01223641, 0.00094449, 0.03911438, - 0.00432476, 0.0222066, -0.0300624, -0.05418365, -0.0261539 + -0.04184828, -0.01223641, 0.0009445, 0.03911434, 0.0043249, + 0.02220661, -0.03006243, -0.05418363, -0.02615385 ], [ 0.02282745, -0.00143833, 0.01918138, 0.00545033, -0.01258384, -0.00303765, -0.00511231, 0.02166323, @@ -814,20 +813,19 @@ class AttentionWrapperTest(test.TestCase): dtype=float32), h=array( [[ - -0.00912324, -0.00827531, -0.00261125, 0.00765155, - -0.00636977, -0.00217171, -0.00643897, -0.00876094, - 0.01136858 + -0.00910065, -0.00827571, -0.00259689, 0.00764857, + -0.00635579, -0.00218579, -0.00633918, -0.00875511, + 0.01134532 ], [ - 0.00629574, -0.00427225, -0.00182555, 0.01597121, - -0.01734862, -0.01224119, -0.00245434, -0.0023048, - -0.01077694 + 0.00626597, -0.004241, -0.00181303, 0.01597157, -0.0173375, + -0.01224921, -0.00244522, -0.00231299, -0.0107822 ], [ - -0.00391358, 0.00161294, -0.00683057, -0.00569066, - 0.0024297, -0.00658555, 0.00429511, -0.00309842, -0.00520863 + -0.00391383, 0.00162017, -0.00682621, -0.00570264, + 0.00244099, -0.00659772, 0.00426475, -0.00309861, + -0.00520028 ], [ - -0.02087489, -0.00603306, 0.00047561, 0.01920064, - 0.00213868, 0.01115329, -0.01526589, -0.02687524, - -0.01297526 + -0.02087484, -0.00603306, 0.00047561, 0.01920062, + 0.00213875, 0.01115329, -0.0152659, -0.02687523, -0.01297523 ], [ 0.01138975, -0.00070959, 0.00986007, 0.0027323, -0.00645386, -0.00152054, -0.00257339, 0.01103063, 0.00800891 @@ -835,17 +833,17 @@ class AttentionWrapperTest(test.TestCase): dtype=float32)), attention=array( [[ - -0.24277903, -0.07872239, 0.15671065, 0.24866652, 0.30515283, - -0.24618667, 0.05076243, 0.18785232, -0.16496192, -0.13840519 + -0.2422343, -0.07791215, 0.15451413, 0.24738336, 0.30900475, + -0.2468522, 0.04992349, 0.18749571, -0.158785, -0.13678965 ], [ - 0.39683694, 0.12260016, -0.06023812, -0.09247129, 0.11370847, - -0.1547074, 0.00654885, -0.2649104, 0.08399855, 0.18763182 + 0.40035242, 0.122996, -0.06085056, -0.09197087, 0.11368532, + -0.1530287, 0.00566146, -0.26885769, 0.08546556, 0.18886761 ], [ - -0.43280473, 0.07398786, -0.01561838, 0.19199508, -0.02651543, - -0.21643978, -0.02017862, 0.00162333, 0.21297936, 0.25590736 + -0.43113324, 0.07519463, -0.01551815, 0.1913045, -0.02693798, + -0.21668874, -0.02155712, 0.00133973, 0.21180835, 0.25578696 ], [ - 0.07059769, 0.16451193, 0.01174642, 0.04646424, 0.14275651, - 0.07944378, -0.10852743, 0.15305836, 0.02151343, -0.05589932 + 0.07059937, 0.16451585, 0.0117467, 0.04646534, 0.1427599, + 0.07944567, -0.10853001, 0.153062, 0.02151395, -0.05590065 ], [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32), time=3, diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py index 732b0db930..cb0cb4f8c3 100644 --- a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py +++ b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py @@ -245,10 +245,14 @@ class BeamSearchDecoderTest(test.TestCase): if has_attention: inputs = np.random.randn(batch_size, decoder_max_time, input_depth).astype(np.float32) + tiled_inputs = beam_search_decoder.tile_batch( + inputs, multiplier=beam_width) + tiled_sequence_length = beam_search_decoder.tile_batch( + encoder_sequence_length, multiplier=beam_width) attention_mechanism = attention_wrapper.BahdanauAttention( num_units=attention_depth, - memory=inputs, - memory_sequence_length=encoder_sequence_length) + memory=tiled_inputs, + memory_sequence_length=tiled_sequence_length) cell = attention_wrapper.AttentionWrapper( cell=cell, attention_mechanism=attention_mechanism, diff --git a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py index 3a3d0b8aec..04b38159bb 100644 --- a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py +++ b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py @@ -28,6 +28,7 @@ from tensorflow.python.framework import tensor_shape from tensorflow.python.layers import base as layers_base from tensorflow.python.layers import core as layers_core from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops @@ -72,6 +73,8 @@ def _prepare_memory(memory, memory_sequence_length, check_inner_dims_defined): """ memory = nest.map_structure( lambda m: ops.convert_to_tensor(m, name="memory"), memory) + memory_sequence_length = ops.convert_to_tensor( + memory_sequence_length, name="memory_sequence_length") if check_inner_dims_defined: def _check_dims(m): if not m.get_shape()[2:].is_fully_defined(): @@ -85,15 +88,24 @@ def _prepare_memory(memory, memory_sequence_length, check_inner_dims_defined): memory_sequence_length, maxlen=array_ops.shape(nest.flatten(memory)[0])[1], dtype=nest.flatten(memory)[0].dtype) + seq_len_batch_size = ( + memory_sequence_length.shape[0].value + or array_ops.shape(memory_sequence_length)[0]) def _maybe_mask(m, seq_len_mask): rank = m.get_shape().ndims rank = rank if rank is not None else array_ops.rank(m) extra_ones = array_ops.ones(rank - 2, dtype=dtypes.int32) + m_batch_size = m.shape[0].value or array_ops.shape(m)[0] if memory_sequence_length is not None: - seq_len_mask = array_ops.reshape( - seq_len_mask, - array_ops.concat((array_ops.shape(seq_len_mask), extra_ones), 0)) - return m * seq_len_mask + message = ("memory_sequence_length and memory tensor batch sizes do not " + "match.") + with ops.control_dependencies([ + check_ops.assert_equal( + seq_len_batch_size, m_batch_size, message=message)]): + seq_len_mask = array_ops.reshape( + seq_len_mask, + array_ops.concat((array_ops.shape(seq_len_mask), extra_ones), 0)) + return m * seq_len_mask else: return m return nest.map_structure(lambda m: _maybe_mask(m, seq_len_mask), memory) @@ -108,7 +120,8 @@ class _BaseAttentionMechanism(AttentionMechanism): """ def __init__(self, query_layer, memory, memory_sequence_length=None, - memory_layer=None, check_inner_dims_defined=True, name=None): + memory_layer=None, check_inner_dims_defined=True, + name=None): """Construct base AttentionMechanism class. Args: @@ -147,6 +160,8 @@ class _BaseAttentionMechanism(AttentionMechanism): self._keys = ( self.memory_layer(self._values) if self.memory_layer # pylint: disable=not-callable else self._values) + self._batch_size = ( + self._keys.shape[0].value or array_ops.shape(self._keys)[0]) @property def memory_layer(self): @@ -164,6 +179,10 @@ class _BaseAttentionMechanism(AttentionMechanism): def keys(self): return self._keys + @property + def batch_size(self): + return self._batch_size + class LuongAttention(_BaseAttentionMechanism): """Implements Luong-style (multiplicative) attention scoring. @@ -213,14 +232,12 @@ class LuongAttention(_BaseAttentionMechanism): self._scale = scale self._name = name - def __call__(self, query, tiling_factor=1): + def __call__(self, query): """Score the query based on the keys and values. Args: query: Tensor of dtype matching `self.values` and shape `[batch_size, query_depth]`. - tiling_factor: An integer factor for which to tile the batch dimension. - Used with BeamSearchDecoder. Returns: score: Tensor of dtype matching `self.values` and shape @@ -317,14 +334,12 @@ class BahdanauAttention(_BaseAttentionMechanism): self._normalize = normalize self._name = name - def __call__(self, query, tiling_factor=1): + def __call__(self, query): """Score the query based on the keys and values. Args: query: Tensor of dtype matching `self.values` and shape `[batch_size, query_depth]`. - tiling_factor: An integer factor for which to tile the batch dimension. - Used with BeamSearchDecoder. Returns: score: Tensor of dtype matching `self.values` and shape @@ -335,7 +350,7 @@ class BahdanauAttention(_BaseAttentionMechanism): dtype = processed_query.dtype # Reshape from [batch_size, ...] to [batch_size, 1, ...] for broadcasting. processed_query = array_ops.expand_dims(processed_query, 1) - keys = _maybe_tile_batch(self.keys, tiling_factor) + keys = self._keys v = variable_scope.get_variable( "attention_v", [self._num_units], dtype=dtype) if self._normalize: @@ -428,6 +443,7 @@ class AttentionWrapper(core_rnn_cell.RNNCell): cell_input_fn=None, probability_fn=None, output_attention=True, + initial_cell_state=None, name=None): """Construct the `AttentionWrapper`. @@ -454,6 +470,11 @@ class AttentionWrapper(core_rnn_cell.RNNCell): propagated to the next time step via the state and is used there. This flag only controls whether the attention mechanism is propagated up to the next cell in an RNN stack or to the top RNN output. + initial_cell_state: The initial state value to use for the cell when + the user calls `zero_state()`. Note that if this value is provided + now, and the user uses a `batch_size` argument of `zero_state` which + does not match the batch size of `initial_cell_state`, proper + behavior is not guaranteed. name: Name to use when creating ops. """ super(AttentionWrapper, self).__init__() @@ -475,7 +496,7 @@ class AttentionWrapper(core_rnn_cell.RNNCell): if probability_fn is None: probability_fn = nn_ops.softmax else: - if not callable(probability_fn): + if not callable(cell_input_fn): raise TypeError( "probability_fn must be callable, saw type: %s" % type(probability_fn).__name__) @@ -494,6 +515,28 @@ class AttentionWrapper(core_rnn_cell.RNNCell): self._probability_fn = probability_fn self._output_attention = output_attention self._alignment_history = alignment_history + with ops.name_scope(name, "AttentionWrapperInit"): + if initial_cell_state is None: + self._initial_cell_state = None + else: + final_state_tensor = nest.flatten(initial_cell_state)[-1] + state_batch_size = ( + final_state_tensor.shape[0].value + or array_ops.shape(final_state_tensor)[0]) + error_message = ( + "When constructing AttentionWrapper %s: " % self._base_name + + "Non-matching batch sizes between the memory " + "(encoder output) and initial_cell_state. Are you using " + "the BeamSearchDecoder? You may need to tile your initial state " + "via the tf.contrib.seq2seq.tile_batch function with argument " + "multiple=beam_width.") + with ops.control_dependencies( + [check_ops.assert_equal(state_batch_size, + self._attention_mechanism.batch_size, + message=error_message)]): + self._initial_cell_state = nest.map_structure( + lambda s: array_ops.identity(s, name="check_initial_cell_state"), + initial_cell_state) @property def output_size(self): @@ -512,19 +555,38 @@ class AttentionWrapper(core_rnn_cell.RNNCell): def zero_state(self, batch_size, dtype): with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]): + if self._initial_cell_state is not None: + cell_state = self._initial_cell_state + else: + cell_state = self._cell.zero_state(batch_size, dtype) + error_message = ( + "When calling zero_state of AttentionWrapper %s: " % self._base_name + + "Non-matching batch sizes between the memory " + "(encoder output) and the requested batch size. Are you using " + "the BeamSearchDecoder? If so, make sure your encoder output has " + "been tiled to beam_width via tf.contrib.seq2seq.tile_batch, and " + "the batch_size= argument passed to zero_state is " + "batch_size * beam_width.") + with ops.control_dependencies( + [check_ops.assert_equal(batch_size, + self._attention_mechanism.batch_size, + message=error_message)]): + cell_state = nest.map_structure( + lambda s: array_ops.identity(s, name="checked_cell_state"), + cell_state) if self._alignment_history: alignment_history = tensor_array_ops.TensorArray( dtype=dtype, size=0, dynamic_size=True) else: alignment_history = () return AttentionWrapperState( - cell_state=self._cell.zero_state(batch_size, dtype), + cell_state=cell_state, time=array_ops.zeros([], dtype=dtypes.int32), attention=_zero_state_tensors(self._attention_size, batch_size, dtype), alignment_history=alignment_history) - def __call__(self, inputs, state, tiling_factor=1): + def call(self, inputs, state): """Perform a step of attention-wrapped RNN. - Step 1: Mix the `inputs` and previous step's `attention` output via @@ -543,8 +605,6 @@ class AttentionWrapper(core_rnn_cell.RNNCell): inputs: (Possibly nested tuple of) Tensor, the input at this time step. state: An instance of `AttentionWrapperState` containing tensors from the previous time step. - tiling_factor: An integer factor for which to tile the batch dimension. - Used with BeamSearchDecoder. Returns: A tuple `(attention_or_cell_output, next_state)`, where: @@ -552,81 +612,67 @@ class AttentionWrapper(core_rnn_cell.RNNCell): - `attention_or_cell_output` depending on `output_attention`. - `next_state` is an instance of `DynamicAttentionWrapperState` containing the state calculated at this time step. - - Raises: - NotImplementedError: if `scope` is not `None`. """ - # Step 1: Calculate the true inputs to the cell based on the - # previous attention value. - cell_inputs = self._cell_input_fn(inputs, state.attention) - cell_state = state.cell_state - cell_output, next_cell_state = self._cell(cell_inputs, cell_state) - - score = self._attention_mechanism(cell_output, tiling_factor) - alignments = self._probability_fn(score) - - # Reshape from [batch_size, memory_time] to [batch_size, 1, memory_time] - expanded_alignments = array_ops.expand_dims(alignments, 1) - # Context is the inner product of alignments and values along the - # memory time dimension. - # alignments shape is - # [batch_size, 1, memory_time] - # attention_mechanism.values shape is - # [batch_size, memory_time, attention_values_dim] - # the batched matmul is over memory_time, so the output shape is - # [batch_size, 1, attention_values_dim]. - # we then squeeze out the singleton dim. - attention_mechanism_values = _maybe_tile_batch( - self._attention_mechanism.values, tiling_factor) - - context = math_ops.matmul(expanded_alignments, attention_mechanism_values) - context = array_ops.squeeze(context, [1]) - - if self._attention_layer is not None: - attention = self._attention_layer( - array_ops.concat([cell_output, context], 1)) - else: - attention = context + with variable_scope.variable_scope("attention"): + # Step 1: Calculate the true inputs to the cell based on the + # previous attention value. + cell_inputs = self._cell_input_fn(inputs, state.attention) + cell_state = state.cell_state + cell_output, next_cell_state = self._cell(cell_inputs, cell_state) + + cell_batch_size = ( + cell_output.shape[0].value or array_ops.shape(cell_output)[0]) + error_message = ( + "When applying AttentionWrapper %s: " % self.name + + "Non-matching batch sizes between the memory " + "(encoder output) and the query (decoder output). Are you using " + "the BeamSearchDecoder? You may need to tile your memory input via " + "the tf.contrib.seq2seq.tile_batch function with argument " + "multiple=beam_width.") + with ops.control_dependencies( + [check_ops.assert_equal(cell_batch_size, + self._attention_mechanism.batch_size, + message=error_message)]): + cell_output = array_ops.identity( + cell_output, name="checked_cell_output") + + score = self._attention_mechanism(cell_output) + alignments = self._probability_fn(score) + + # Reshape from [batch_size, memory_time] to [batch_size, 1, memory_time] + expanded_alignments = array_ops.expand_dims(alignments, 1) + # Context is the inner product of alignments and values along the + # memory time dimension. + # alignments shape is + # [batch_size, 1, memory_time] + # attention_mechanism.values shape is + # [batch_size, memory_time, attention_mechanism.num_units] + # the batched matmul is over memory_time, so the output shape is + # [batch_size, 1, attention_mechanism.num_units]. + # we then squeeze out the singleton dim. + attention_mechanism_values = self._attention_mechanism.values + context = math_ops.matmul(expanded_alignments, attention_mechanism_values) + context = array_ops.squeeze(context, [1]) + + if self._attention_layer is not None: + attention = self._attention_layer( + array_ops.concat([cell_output, context], 1)) + else: + attention = context - if self._alignment_history: - alignment_history = state.alignment_history.write( - state.time, alignments) - else: - alignment_history = () + if self._alignment_history: + alignment_history = state.alignment_history.write( + state.time, alignments) + else: + alignment_history = () - next_state = AttentionWrapperState( - time=state.time + 1, - cell_state=next_cell_state, - attention=attention, - alignment_history=alignment_history) + next_state = AttentionWrapperState( + time=state.time + 1, + cell_state=next_cell_state, + attention=attention, + alignment_history=alignment_history) if self._output_attention: return attention, next_state else: return cell_output, next_state - - -def _maybe_tile_batch(t, tiling_factor): - """Tile the tensor's batch by tiling_factor. - - Here, we tile t such that it looks like [b1, b1, ..., ..., bN, bN, ...]. - - Args: - t: The tensor to tile. - tiling_factor: The amount to tile it. - - Returns: - The tiled tensor. - """ - if tiling_factor == 1: - return t - - shape = t.get_shape().as_list() - shape = [shape[0] * tiling_factor] + shape[1:] - tile_values = len(shape)*[1] - tile_values.insert(1, tiling_factor) - t = array_ops.expand_dims(t, 1) - t = array_ops.tile(t, tile_values) - t = array_ops.reshape(t, shape) - t.set_shape(shape) - return t diff --git a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py index 55ef21a5a0..289da8e6ae 100644 --- a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py +++ b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py @@ -33,6 +33,7 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import embedding_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import tensor_array_ops from tensorflow.python.util import nest @@ -41,6 +42,7 @@ __all__ = [ "BeamSearchDecoderState", "BeamSearchDecoder", "FinalBeamSearchDecoderOutput", + "tile_batch", ] @@ -70,6 +72,44 @@ class FinalBeamSearchDecoderOutput( pass +def tile_batch(t, multiplier, name=None): + """Tile the batch dimension of tensor t. + + This function takes a tensor t shaped `[batch_size, s0, s1, ...]` composed of + minibatch entries `t[0], ..., t[batch_size - 1]` and tiles it to have a shape + `[batch_size * multiplier, s0, s1, ...]` composed of minibatch entries + `t[0], t[0], ..., t[1], t[1], ...` where each minibatch entry is repeated + `multiplier` times. + + Args: + t: `Tensor` shaped `[batch_size, ...]`. + multiplier: Python int. + name: Name scope for any created operations. + + Returns: + A `Tensor` shaped `[batch_size * multiplier, ...]`. + + Raises: + ValueError: if `t` does not have a statically known rank or it's < 1. + """ + with ops.name_scope(name, "tile_batch", [t, multiplier]): + t = ops.convert_to_tensor(t, name="t") + shape_t = array_ops.shape(t) + if t.shape.ndims is None or t.shape.ndims < 1: + raise ValueError("t must have statically known rank") + tiling = [1] * (t.shape.ndims + 1) + tiling[1] = multiplier + tiled_static_batch_size = ( + t.shape[0].value * multiplier if t.shape[0].value is not None else None) + tiled = array_ops.tile(array_ops.expand_dims(t, 1), tiling) + tiled = array_ops.reshape( + tiled, array_ops.concat(([shape_t[0] * multiplier], shape_t[1:]), 0)) + tiled.set_shape( + tensor_shape.TensorShape( + [tiled_static_batch_size]).concatenate(t.shape[1:])) + return tiled + + class BeamSearchDecoder(decoder.Decoder): """BeamSearch sampling decoder.""" @@ -130,8 +170,9 @@ class BeamSearchDecoder(decoder.Decoder): self._batch_size = array_ops.size(start_tokens) self._beam_width = beam_width self._length_penalty_weight = length_penalty_weight - self._initial_cell_state = nest.map_structure(self._maybe_split_batch_beams, - initial_state) + self._initial_cell_state = nest.map_structure( + self._maybe_split_batch_beams, + initial_state, self._cell.state_size) self._start_tokens = array_ops.tile( array_ops.expand_dims(self._start_tokens, 1), [1, self._beam_width]) self._start_inputs = self._embedding_fn(self._start_tokens) @@ -223,19 +264,23 @@ class BeamSearchDecoder(decoder.Decoder): beam_search_decoder_output=outputs, predicted_ids=predicted_ids) return outputs, final_state - def _merge_batch_beams(self, t): + def _merge_batch_beams(self, t, s=None): """Merges the tensor from a batch of beams into a batch by beams. - More exactly, t is a tensor of dimension [batch_size, beam_width, ...]. We - reshape this into [batch_size*beam_width, ...] + More exactly, t is a tensor of dimension [batch_size, beam_width, s]. We + reshape this into [batch_size*beam_width, s] Args: - t: Tensor of dimension [batch_size, beam_width, ...] + t: Tensor of dimension [batch_size, beam_width, s] + s: (Possibly known) depth shape. Returns: - A reshaped version of t with dimension [batch_size * beam_width, ...]. + A reshaped version of t with dimension [batch_size * beam_width, s]. """ - t_static_shape = t.shape + if isinstance(s, ops.Tensor): + s = tensor_util.constant_value_as_shape(s) + else: + s = tensor_shape.TensorShape(s) t_shape = array_ops.shape(t) static_batch_size = tensor_util.constant_value(self._batch_size) batch_size_beam_width = ( @@ -245,67 +290,105 @@ class BeamSearchDecoder(decoder.Decoder): t, array_ops.concat( ([self._batch_size * self._beam_width], t_shape[2:]), 0)) reshaped_t.set_shape( - (tensor_shape.TensorShape([batch_size_beam_width]) - .concatenate(t_static_shape[2:]))) + (tensor_shape.TensorShape([batch_size_beam_width]).concatenate(s))) return reshaped_t - def _split_batch_beams(self, t): + def _split_batch_beams(self, t, s=None): """Splits the tensor from a batch by beams into a batch of beams. - More exactly, t is a tensor of dimension [batch_size*beam_width, ...]. We - reshape this into [batch_size, beam_width, ...] + More exactly, t is a tensor of dimension [batch_size*beam_width, s]. We + reshape this into [batch_size, beam_width, s] Args: - t: Tensor of dimension [batch_size*beam_width, ...] + t: Tensor of dimension [batch_size*beam_width, s]. + s: (Possibly known) depth shape. Returns: - A reshaped version of t with dimension [batch_size, beam_width, ...]. + A reshaped version of t with dimension [batch_size, beam_width, s]. + + Raises: + ValueError: If, after reshaping, the new tensor is not shaped + `[batch_size, beam_width, s]` (assuming batch_size and beam_width + are known statically). """ - t_static_shape = t.shape + if isinstance(s, ops.Tensor): + s = tensor_util.constant_value_as_shape(s) + else: + s = tensor_shape.TensorShape(s) t_shape = array_ops.shape(t) reshaped_t = array_ops.reshape( t, array_ops.concat( ([self._batch_size, self._beam_width], t_shape[1:]), 0)) static_batch_size = tensor_util.constant_value(self._batch_size) - reshaped_t.set_shape( - (tensor_shape.TensorShape([static_batch_size, self._beam_width]) - .concatenate(t_static_shape[1:]))) + expected_reshaped_shape = tensor_shape.TensorShape( + [static_batch_size, self._beam_width]).concatenate(s) + if not reshaped_t.shape.is_compatible_with(expected_reshaped_shape): + raise ValueError("Unexpected behavior when reshaping between beam width " + "and batch size. The reshaped tensor has shape: %s. " + "We expected it to have shape " + "(batch_size, beam_width, depth) == %s. Perhaps you " + "forgot to create a zero_state with " + "batch_size=encoder_batch_size * beam_width?" + % (reshaped_t.shape, expected_reshaped_shape)) + reshaped_t.set_shape(expected_reshaped_shape) return reshaped_t - def _maybe_split_batch_beams(self, t): + def _maybe_split_batch_beams(self, t, s): """Maybe splits the tensor from a batch by beams into a batch of beams. We do this so that we can use nest and not run into problems with shapes. Args: - t: Tensor of dimension [batch_size*beam_width, ...] + t: Tensor of dimension [batch_size*beam_width, s] + s: Tensor, Python int, or TensorShape. Returns: Either a reshaped version of t with dimension - [batch_size, beam_width, ...] if t's first dimension is of size + [batch_size, beam_width, s] if t's first dimension is of size batch_size*beam_width or t if not. + + Raises: + TypeError: If t is an instance of TensorArray. + ValueError: If the rank of t is not statically known. """ - t_shape = t.get_shape().as_list() - if len(t_shape) >= 1: - return self._split_batch_beams(t) + if isinstance(t, tensor_array_ops.TensorArray): + raise TypeError( + "TensorArray state is not supported by BeamSearchDecoder: %s" + % t.name) + if t.shape.ndims is None: + raise ValueError( + "Expected tensor (%s) to have known rank, but ndims == None." % t) + if t.shape.ndims >= 1: + return self._split_batch_beams(t, s) else: return t - def _maybe_merge_batch_beams(self, t): + def _maybe_merge_batch_beams(self, t, s): """Splits the tensor from a batch by beams into a batch of beams. - More exactly, t is a tensor of dimension [batch_size*beam_width, ...]. We - reshape this into [batch_size, beam_width, ...] + More exactly, t is a tensor of dimension [batch_size*beam_width, s]. We + reshape this into [batch_size, beam_width, s] Args: - t: Tensor of dimension [batch_size*beam_width, ...] + t: Tensor of dimension [batch_size*beam_width, s] + s: Tensor, Python int, or TensorShape. Returns: - A reshaped version of t with dimension [batch_size, beam_width, ...]. + A reshaped version of t with dimension [batch_size, beam_width, s]. + + Raises: + TypeError: If t is an instance of TensorArray. + ValueError: If the rank of t is not statically known. """ - t_shape = t.get_shape().as_list() - if len(t_shape) >= 2: - return self._merge_batch_beams(t) + if isinstance(t, tensor_array_ops.TensorArray): + raise TypeError( + "TensorArray state is not supported by BeamSearchDecoder: %s" + % t.name) + if t.shape.ndims is None: + raise ValueError( + "Expected tensor (%s) to have known rank, but ndims == None." % t) + if t.shape.ndims >= 2: + return self._merge_batch_beams(t, s) else: return t @@ -328,20 +411,18 @@ class BeamSearchDecoder(decoder.Decoder): with ops.name_scope(name, "BeamSearchDecoderStep", (time, inputs, state)): cell_state = state.cell_state - inputs = nest.map_structure(self._merge_batch_beams, inputs) - cell_state = nest.map_structure(self._maybe_merge_batch_beams, cell_state) - try: - cell_outputs, next_cell_state = self._cell( - inputs, cell_state, tiling_factor=beam_width) - except TypeError as e: - if "unexpected keyword argument 'tiling_factor'" in str(e): - cell_outputs, next_cell_state = self._cell(inputs, cell_state) - else: - raise - - cell_outputs = nest.map_structure(self._split_batch_beams, cell_outputs) - next_cell_state = nest.map_structure(self._maybe_split_batch_beams, - next_cell_state) + inputs = nest.map_structure( + lambda inp: self._merge_batch_beams(inp, s=inp.shape[2:]), inputs) + cell_state = nest.map_structure( + self._maybe_merge_batch_beams, + cell_state, self._cell.state_size) + cell_outputs, next_cell_state = self._cell(inputs, cell_state) + + cell_outputs = nest.map_structure( + lambda out: self._split_batch_beams(out, out.shape[1:]), cell_outputs) + next_cell_state = nest.map_structure( + self._maybe_split_batch_beams, + next_cell_state, self._cell.state_size) if self._output_layer is not None: cell_outputs = self._output_layer(cell_outputs) -- GitLab From 0293b46e724b85546a2175fdcb3992f44f3c0ef4 Mon Sep 17 00:00:00 2001 From: Vijay Vasudevan Date: Wed, 26 Apr 2017 12:03:45 -0800 Subject: [PATCH 012/697] AllocationRegistry: only check fail if two different allocator types are defined for the same name and priority. Change: 154333776 --- .../core/framework/allocator_registry.cc | 27 ++++++++++++++----- .../core/framework/allocator_registry.h | 7 +++-- 2 files changed, 26 insertions(+), 8 deletions(-) diff --git a/tensorflow/core/framework/allocator_registry.cc b/tensorflow/core/framework/allocator_registry.cc index 946050687d..486be39ae3 100644 --- a/tensorflow/core/framework/allocator_registry.cc +++ b/tensorflow/core/framework/allocator_registry.cc @@ -26,22 +26,37 @@ AllocatorRegistry* AllocatorRegistry::Global() { return global_allocator_registry; } -bool AllocatorRegistry::CheckForDuplicates(const string& name, int priority) { +Allocator* AllocatorRegistry::GetRegisteredAllocator(const string& name, + int priority) { for (auto entry : allocators_) { if (!name.compare(entry.name) && priority == entry.priority) { - return true; + return entry.allocator; } } - return false; + return nullptr; } void AllocatorRegistry::Register(const string& name, int priority, Allocator* allocator) { CHECK(!name.empty()) << "Need a valid name for Allocator"; CHECK_GE(priority, 0) << "Priority needs to be non-negative"; - CHECK(!CheckForDuplicates(name, priority)) - << "Allocator with name: [" << name << "] and priority: [" << priority - << "] already registered"; + + Allocator* existing = GetRegisteredAllocator(name, priority); + if (existing != nullptr) { + // A duplicate is if the registration name and priority match + // but the Allocator::Name()'s don't match. + CHECK_EQ(existing->Name(), allocator->Name()) + << "Allocator with name: [" << name << "], type [" << existing->Name() + << "], priority: [" << priority + << "] already registered. Choose a different name to register " + << "an allocator of type " << allocator->Name(); + + // The allocator names match, so we can just return. + // It should be safe to delete the allocator since the caller + // gives up ownership of it. + delete allocator; + return; + } AllocatorRegistryEntry tmp_entry; tmp_entry.name = name; diff --git a/tensorflow/core/framework/allocator_registry.h b/tensorflow/core/framework/allocator_registry.h index c419366ae1..b26e79ac3b 100644 --- a/tensorflow/core/framework/allocator_registry.h +++ b/tensorflow/core/framework/allocator_registry.h @@ -27,7 +27,8 @@ namespace tensorflow { // A global AllocatorRegistry is used to hold allocators for CPU backends class AllocatorRegistry { public: - // Add an allocator to the registry. + // Add an allocator to the registry. Caller releases ownership of + // 'allocator'. void Register(const string& name, int priority, Allocator* allocator); // Return allocator with highest priority @@ -44,7 +45,9 @@ class AllocatorRegistry { Allocator* allocator; // not owned } AllocatorRegistryEntry; - bool CheckForDuplicates(const string& name, int priority); + // Returns the Allocator registered for 'name' and 'priority', + // or 'nullptr' if not found. + Allocator* GetRegisteredAllocator(const string& name, int priority); std::vector allocators_; Allocator* m_curr_allocator_; // not owned -- GitLab From 4f0dacd1f1995414b6656f9398aca133cda1c860 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 26 Apr 2017 12:36:21 -0800 Subject: [PATCH 013/697] Add a conversion from xla session module to a list of node structure as a step of xla computation transformation Change: 154337650 --- .../contrib/xla_tf_graph/xla_tf_graph_util.cc | 183 +++++++++++++++++- .../contrib/xla_tf_graph/xla_tf_graph_util.h | 29 +++ .../xla_tf_graph/xla_tf_graph_util_test.cc | 80 +++++++- 3 files changed, 289 insertions(+), 3 deletions(-) diff --git a/tensorflow/contrib/xla_tf_graph/xla_tf_graph_util.cc b/tensorflow/contrib/xla_tf_graph/xla_tf_graph_util.cc index 3bad9b8067..f0dabc08a4 100644 --- a/tensorflow/contrib/xla_tf_graph/xla_tf_graph_util.cc +++ b/tensorflow/contrib/xla_tf_graph/xla_tf_graph_util.cc @@ -17,13 +17,182 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/core/platform/protobuf.h" namespace tensorflow { namespace xla_tf_graph { namespace { -constexpr const char* const GRAPH_NAME = "xla_tf_graph_util"; +constexpr const char* const GRAPH_NAME = "xla_tf_graph"; +constexpr const char* const NODE_NAME_PREFIX = "xla"; + +Status ConvertPrimitiveTypeToDataType(const xla::PrimitiveType p_type, + DataType* d_type) { + switch (p_type) { + case xla::PRED: + *d_type = DT_BOOL; + return Status::OK(); + case xla::S8: + *d_type = DT_INT8; + return Status::OK(); + case xla::S16: + *d_type = DT_INT16; + return Status::OK(); + case xla::S32: + *d_type = DT_INT32; + return Status::OK(); + case xla::S64: + *d_type = DT_INT64; + return Status::OK(); + case xla::U8: + *d_type = DT_UINT8; + return Status::OK(); + case xla::U16: + *d_type = DT_UINT16; + return Status::OK(); + case xla::F16: + *d_type = DT_HALF; + return Status::OK(); + case xla::F32: + *d_type = DT_FLOAT; + return Status::OK(); + case xla::F64: + *d_type = DT_DOUBLE; + return Status::OK(); + default: + return errors::InvalidArgument( + "Unsupported PrimitiveType in ConvertPrimitiveTypeToDataType ", + xla::PrimitiveType_Name(p_type)); + } +} + +Status ConvertXlaShapeToTensorShapeType(const xla::Shape& xla_shape, + std::vector* tensor_shapes, + std::vector* data_types) { + switch (xla_shape.element_type()) { + case xla::TUPLE: { + for (const xla::Shape& element_shape : xla_shape.tuple_shapes()) { + if (element_shape.element_type() == xla::TUPLE) { + return errors::InvalidArgument("Nested tuple is not allowed."); + } + TF_RETURN_IF_ERROR(ConvertXlaShapeToTensorShapeType( + element_shape, tensor_shapes, data_types)); + } + return Status::OK(); + } + case xla::PRED: + case xla::S8: + case xla::S16: + case xla::S32: + case xla::S64: + case xla::U8: + case xla::U16: + case xla::U32: + case xla::U64: + case xla::F16: + case xla::F32: + case xla::F64: { + TensorShape shape; + DataType type; + TF_RETURN_IF_ERROR( + ConvertPrimitiveTypeToDataType(xla_shape.element_type(), &type)); + for (const int64& dim : xla_shape.dimensions()) { + shape.AddDim(dim); + } + tensor_shapes->emplace_back(shape); + data_types->emplace_back(type); + return Status::OK(); + } + default: + return errors::InvalidArgument( + "Unsupported PrimitiveType in ConvertXlaShapeToTensorShapeType ", + xla::PrimitiveType_Name(xla_shape.element_type())); + } +} + +string BuildXlaNodeName(const xla::OperationRequest& operation_request, + const string& xla_op_type, const string& suffix) { + const string name = strings::StrCat( + NODE_NAME_PREFIX, "/", operation_request.output_handle().handle(), "/", + xla_op_type); + if (suffix.empty()) { + return name; + } else { + return strings::StrCat(name, "/", suffix); + } +} + +string BuildXlaNodeName(const xla::OperationRequest& operation_request, + const string& xla_op_type) { + return BuildXlaNodeName(operation_request, xla_op_type, ""); +} + +string BuildXlaNodeOp(const protobuf::Message& msg, const string& suffix) { + return strings::StrCat(msg.GetDescriptor()->name(), "/", suffix); +} + +string BuildXlaNodeOp(const protobuf::Message& msg) { + return BuildXlaNodeOp(msg, ""); +} + +Status ConvertOpRequestToXlaNode(const xla::OperationRequest& operation_request, + XlaNode* xla_node) { + const xla::OpRequest& op_request = operation_request.request(); + switch (op_request.op_case()) { + case xla::OpRequest::kBinaryOpRequest: { + const xla::BinaryOpRequest& op = op_request.binary_op_request(); + xla_node->op_type = + BuildXlaNodeOp(op, xla::BinaryOperation_Name(op.binop())); + xla_node->name = BuildXlaNodeName(operation_request, xla_node->op_type); + xla_node->input_ids.emplace_back(std::make_tuple(op.lhs().handle(), 0)); + xla_node->input_ids.emplace_back(std::make_tuple(op.rhs().handle(), 0)); + for (const int64& dim : op.broadcast_dimensions()) { + xla_node->broadcast_dimensions.emplace_back(dim); + } + break; + } + case xla::OpRequest::kParameterRequest: { + const xla::ParameterRequest& op = op_request.parameter_request(); + xla_node->op_type = BuildXlaNodeOp(op, ""); + xla_node->name = + BuildXlaNodeName(operation_request, xla_node->op_type, op.name()); + break; + } + case xla::OpRequest::kVariadicOpRequest: { + const xla::VariadicOpRequest& op = op_request.variadic_op_request(); + xla_node->op_type = + BuildXlaNodeOp(op, xla::VariadicOperation_Name(op.varop())); + xla_node->name = BuildXlaNodeName(operation_request, xla_node->op_type); + for (const xla::ComputationDataHandle& handle : op.operands()) { + xla_node->input_ids.emplace_back(std::make_tuple(handle.handle(), 0)); + } + break; + } + case xla::OpRequest::kGetTupleElementRequest: { + const xla::GetTupleElementRequest& op = + op_request.get_tuple_element_request(); + xla_node->op_type = BuildXlaNodeOp(op); + xla_node->name = BuildXlaNodeName(operation_request, xla_node->op_type); + xla_node->input_ids.emplace_back( + std::make_tuple(op.operand().handle(), op.index())); + break; + } + default: + // TODO(satok): Implement all possible cases. + LOG(FATAL) << "Op request: " << op_request.op_case() + << " is not supported yet."; + break; + } + + CHECK(!xla_node->name.empty()); + CHECK(!xla_node->op_type.empty()); + + TF_RETURN_IF_ERROR(ConvertXlaShapeToTensorShapeType( + operation_request.output_shape(), &xla_node->output_shapes, + &xla_node->output_data_types)); + return Status::OK(); +} void SetupXlaCpuClient(std::unique_ptr* flib_def, std::unique_ptr* flr, @@ -67,5 +236,17 @@ ConvertTfGraphToXlaSessionModule(const std::vector& args, return result.computation.Snapshot(); } +xla::StatusOr> +ConvertXlaSessionModuleToXlaNodes(const xla::SessionModule& session_module) { + std::unordered_map xla_nodes; + for (const auto& operation_request : session_module.entry().requests()) { + XlaNode xla_node; + TF_RETURN_IF_ERROR( + ConvertOpRequestToXlaNode(operation_request.second, &xla_node)); + xla_nodes.emplace(operation_request.first, xla_node); + } + return std::move(xla_nodes); +} + } // namespace xla_tf_graph } // namespace tensorflow diff --git a/tensorflow/contrib/xla_tf_graph/xla_tf_graph_util.h b/tensorflow/contrib/xla_tf_graph/xla_tf_graph_util.h index 89dca876b0..e635290851 100644 --- a/tensorflow/contrib/xla_tf_graph/xla_tf_graph_util.h +++ b/tensorflow/contrib/xla_tf_graph/xla_tf_graph_util.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_CONTRIB_XLA_TF_GRAPH_XLA_TF_GRAPH_UTIL_H_ #define TENSORFLOW_CONTRIB_XLA_TF_GRAPH_XLA_TF_GRAPH_UTIL_H_ +#include + #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -32,11 +34,38 @@ namespace xla_tf_graph { // implementing xla computations so that they can do experiments on their // specialized environments. +// A structure to represent typed attributes of TensorFlow graph node. +// This structure contains op specific attributes as members so that +// we can treat them explicitly. +struct XlaNode { + // Unique node name + string name; + // Op type of xla computation + string op_type; + // List of pair of unique id and port of input node. + // We store this value instead + // of node name in order not to wait for all XlaNodes to be constructed. + std::vector> input_ids; + // Oputput shapes + std::vector output_shapes; + // Output data types + std::vector output_data_types; + + //--------------------------- + // Op specific attributes + // #xla::OpRequest::kBinaryOpRequest + std::vector broadcast_dimensions; +}; + // Convert a tf graph to a xla session module xla::StatusOr> ConvertTfGraphToXlaSessionModule(const std::vector& args, std::unique_ptr graph); +// Convert a xla session module to a map to XlaNode from unique id +xla::StatusOr> +ConvertXlaSessionModuleToXlaNodes(const xla::SessionModule& session_module); + } // namespace xla_tf_graph } // namespace tensorflow diff --git a/tensorflow/contrib/xla_tf_graph/xla_tf_graph_util_test.cc b/tensorflow/contrib/xla_tf_graph/xla_tf_graph_util_test.cc index bab4256187..23649957f3 100644 --- a/tensorflow/contrib/xla_tf_graph/xla_tf_graph_util_test.cc +++ b/tensorflow/contrib/xla_tf_graph/xla_tf_graph_util_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/cc/framework/scope.h" #include "tensorflow/cc/ops/function_ops.h" #include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/core/platform/test.h" @@ -27,6 +28,7 @@ static std::unique_ptr BuildAddGraph() { Scope scope = Scope::NewRootScope().ExitOnError(); auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0); auto b = ops::_Arg(scope.WithOpName("B"), DT_INT32, 1); + // See tf2xla/kernels/binary_ops.cc auto c = ops::Add(scope.WithOpName("C"), a, b); auto d = ops::_Retval(scope.WithOpName("D"), c, 0); std::unique_ptr graph(new Graph(OpRegistry::Global())); @@ -34,16 +36,75 @@ static std::unique_ptr BuildAddGraph() { return graph; } -TEST(XlaTfGraphUtil, ConvertTfGraphToHloModule) { +static std::vector BuildAddGraphArguments() { // Builds a description of the arguments. std::vector args(2); args[0].kind = XlaCompiler::Argument::kParameter; args[0].type = DT_INT32; - args[0].shape = TensorShape({2}); + // Difference of dimension will add extra broadcast_dimensions. + // broadcast_dimension generates an additional HloInstruction + // in user_computation.cc + args[0].shape = TensorShape({2, 2}); args[1].kind = XlaCompiler::Argument::kParameter; args[1].type = DT_INT32; args[1].shape = TensorShape({2}); + return args; +} + +// CAVEAT: Debug purpose only. +// This function dumps a protobuf string format of HloModule. +static void DumpHloGraphForDebug(const std::vector& args, + std::unique_ptr graph) { + std::unique_ptr flib_def; + std::unique_ptr flr; + std::unique_ptr compiler; + + xla::Client* client = xla::ClientLibrary::LocalClientOrDie(); + XlaOpRegistry::RegisterCompilationKernels(); + + FunctionDefLibrary flib; + flib_def.reset(new FunctionLibraryDefinition(OpRegistry::Global(), flib)); + + // Compiles the graph. + XlaCompiler::Options options; + options.device_type = DeviceType("XLA_CPU_JIT"); + options.client = client; + compiler.reset(new XlaCompiler(options)); + flr.reset(NewFunctionLibraryRuntime(compiler->device_mgr(), /*env=*/nullptr, + compiler->device(), TF_GRAPH_DEF_VERSION, + flib_def.get(), OptimizerOptions(), + /*custom_kernel_creator=*/nullptr)); + + // Compile graph + XlaCompiler::CompilationResult result; + TF_CHECK_OK(compiler->CompileGraph("dump", std::move(graph), flr.get(), args, + &result)); + + // Convert to hlo + xla::Computation& computation = result.computation; + + xla::Service* service( + static_cast(xla::ClientLibrary::GetXlaService( + static_cast(client)->platform()))); + const xla::ComputationTracker& computation_tracker = + service->computation_tracker(); + + auto user_computation_status = + computation_tracker.Resolve(computation.handle()); + TF_CHECK_OK(user_computation_status.status()); + auto user_computation = user_computation_status.ConsumeValueOrDie(); + xla::VersionedComputationHandle versioned_handle = + user_computation->GetVersionedHandle(); + std::unique_ptr hlo_module = std::move( + computation_tracker.BuildHloModule(versioned_handle).ValueOrDie()); + VLOG(1) << "--- DUMP HLO ---"; + VLOG(1) << hlo_module->ToString(); +} + +TEST(XlaTfGraphUtil, ConvertTfGraphToSessionModule) { + // Builds a description of the arguments. + std::vector args = BuildAddGraphArguments(); std::unique_ptr graph = BuildAddGraph(); TF_ASSIGN_OR_ASSERT_OK( @@ -51,6 +112,21 @@ TEST(XlaTfGraphUtil, ConvertTfGraphToHloModule) { ConvertTfGraphToXlaSessionModule(args, std::move(graph))); ASSERT_EQ(5, session_module->entry().requests_size()); + + VLOG(1) << "--- DUMP ---"; + VLOG(1) << session_module->DebugString(); + DumpHloGraphForDebug(args, BuildAddGraph()); +} + +TEST(XlaTfGraphUtil, ConvertXlaSessionModuleToXlaNodes) { + std::vector args = BuildAddGraphArguments(); + std::unique_ptr graph = BuildAddGraph(); + TF_ASSIGN_OR_ASSERT_OK( + std::unique_ptr session_module, + ConvertTfGraphToXlaSessionModule(args, std::move(graph))); + TF_ASSIGN_OR_ASSERT_OK(auto xla_nodes, + ConvertXlaSessionModuleToXlaNodes(*session_module)); + EXPECT_EQ(session_module->entry().requests_size(), xla_nodes.size()); } } // namespace xla_tf_graph -- GitLab From 05c77baecdafb2a44d3788b369a92cec66157162 Mon Sep 17 00:00:00 2001 From: Shanqing Cai Date: Wed, 26 Apr 2017 12:40:12 -0800 Subject: [PATCH 014/697] tfdbg: Improve the way uninitialized tensor values and DT_RESOURCE tensor values are displayed 1. Fix the bug wherein DT_RESOURCE tensors are displayed as "Uninitialized tensor", even if they are initialized. 2. For truly uninitialized tensors, provide more information by printing the TensorProto, which contains the dtype and shape. Change: 154338128 --- tensorflow/python/debug/BUILD | 1 + tensorflow/python/debug/cli/tensor_format.py | 5 +- .../python/debug/cli/tensor_format_test.py | 38 ++++++++++++-- tensorflow/python/debug/lib/debug_data.py | 49 ++++++++++++++----- .../python/debug/lib/debug_data_test.py | 13 +++-- .../python/debug/lib/session_debug_testlib.py | 11 +++-- 6 files changed, 93 insertions(+), 24 deletions(-) diff --git a/tensorflow/python/debug/BUILD b/tensorflow/python/debug/BUILD index 2c8fa53501..f9c908f538 100644 --- a/tensorflow/python/debug/BUILD +++ b/tensorflow/python/debug/BUILD @@ -111,6 +111,7 @@ py_library( srcs = ["cli/tensor_format.py"], srcs_version = "PY2AND3", deps = [ + ":debug_data", ":debugger_cli_common", "//third_party/py/numpy", ], diff --git a/tensorflow/python/debug/cli/tensor_format.py b/tensorflow/python/debug/cli/tensor_format.py index c3c4bcf215..bb7ac31430 100644 --- a/tensorflow/python/debug/cli/tensor_format.py +++ b/tensorflow/python/debug/cli/tensor_format.py @@ -24,6 +24,7 @@ import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.python.debug.cli import debugger_cli_common +from tensorflow.python.debug.lib import debug_data _NUMPY_OMISSION = "...," _NUMPY_DEFAULT_EDGE_ITEMS = 3 @@ -112,10 +113,10 @@ def format_tensor(tensor, (8 + proper_len + 1, 8 + proper_len + 1 + debug_op_len, "yellow") ] - if tensor is None: + if isinstance(tensor, debug_data.InconvertibleTensorProto): if lines: lines.append("") - lines.append("Uninitialized tensor") + lines.extend(str(tensor).split("\n")) return debugger_cli_common.RichTextLines(lines) elif not isinstance(tensor, np.ndarray): # If tensor is not a np.ndarray, return simple text-line representation of diff --git a/tensorflow/python/debug/cli/tensor_format_test.py b/tensorflow/python/debug/cli/tensor_format_test.py index 8392a87367..ec80bb998e 100644 --- a/tensorflow/python/debug/cli/tensor_format_test.py +++ b/tensorflow/python/debug/cli/tensor_format_test.py @@ -20,7 +20,11 @@ from __future__ import print_function import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin +from tensorflow.core.framework import tensor_pb2 +from tensorflow.core.framework import tensor_shape_pb2 +from tensorflow.core.framework import types_pb2 from tensorflow.python.debug.cli import tensor_format +from tensorflow.python.debug.lib import debug_data from tensorflow.python.framework import test_util from tensorflow.python.platform import googletest @@ -363,10 +367,28 @@ class RichTextLinesTest(test_util.TensorFlowTestCase): if i < 1: self.assertNotIn(p + i * 6 + 5, out.annotations) - def testFormatNone(self): - out = tensor_format.format_tensor(None, "a") + def testFormatUninitializedTensor(self): + tensor_proto = tensor_pb2.TensorProto( + dtype=types_pb2.DataType.Value("DT_FLOAT"), + tensor_shape=tensor_shape_pb2.TensorShapeProto( + dim=[tensor_shape_pb2.TensorShapeProto.Dim(size=1)])) + out = tensor_format.format_tensor( + debug_data.InconvertibleTensorProto(tensor_proto, False), "a") + + self.assertEqual(["Tensor \"a\":", "", "Uninitialized tensor:"], + out.lines[:3]) + self.assertEqual(str(tensor_proto).split("\n"), out.lines[3:]) - self.assertEqual(["Tensor \"a\":", "", "Uninitialized tensor"], out.lines) + def testFormatResourceTypeTensor(self): + tensor_proto = tensor_pb2.TensorProto( + dtype=types_pb2.DataType.Value("DT_RESOURCE"), + tensor_shape=tensor_shape_pb2.TensorShapeProto( + dim=[tensor_shape_pb2.TensorShapeProto.Dim(size=1)])) + out = tensor_format.format_tensor( + debug_data.InconvertibleTensorProto(tensor_proto), "a") + + self.assertEqual(["Tensor \"a\":", ""], out.lines[:2]) + self.assertEqual(str(tensor_proto).split("\n"), out.lines[2:]) def testLocateTensorElement1DNoEllipsis(self): a = np.zeros(20) @@ -821,9 +843,15 @@ class RichTextLinesTest(test_util.TensorFlowTestCase): self.assertEqual([12, None], end_cols) def testLocateTensorElementAnnotationsUnavailable(self): - out = tensor_format.format_tensor(None, "a") + tensor_proto = tensor_pb2.TensorProto( + dtype=types_pb2.DataType.Value("DT_FLOAT"), + tensor_shape=tensor_shape_pb2.TensorShapeProto( + dim=[tensor_shape_pb2.TensorShapeProto.Dim(size=1)])) + out = tensor_format.format_tensor( + debug_data.InconvertibleTensorProto(tensor_proto, False), "a") - self.assertEqual(["Tensor \"a\":", "", "Uninitialized tensor"], out.lines) + self.assertEqual(["Tensor \"a\":", "", "Uninitialized tensor:"], + out.lines[:3]) with self.assertRaisesRegexp( AttributeError, "tensor_metadata is not available in annotations"): diff --git a/tensorflow/python/debug/lib/debug_data.py b/tensorflow/python/debug/lib/debug_data.py index 2b11039215..ce4bc82e0a 100644 --- a/tensorflow/python/debug/lib/debug_data.py +++ b/tensorflow/python/debug/lib/debug_data.py @@ -39,6 +39,30 @@ FETCHES_INFO_FILE_TAG = "fetches_info_" FEED_KEYS_INFO_FILE_TAG = "feed_keys_info_" +class InconvertibleTensorProto(object): + """Represents a TensorProto that cannot be converted to np.ndarray.""" + + def __init__(self, tensor_proto, initialized=True): + """Constructor. + + Args: + tensor_proto: the `TensorProto` object that cannot be represented as a + `np.ndarray` object. + initialized: (`bool`) whether the Tensor is initialized. + """ + self._tensor_proto = tensor_proto + self._initialized = initialized + + def __str__(self): + output = "" if self._initialized else "Uninitialized tensor:\n" + output += str(self._tensor_proto) + return output + + @property + def initialized(self): + return self._initialized + + def load_tensor_from_event_file(event_file_path): """Load a tensor from an event file. @@ -69,26 +93,27 @@ def load_tensor_from_event(event): summary.value[0] field. Returns: - The tensor value loaded from the event file, as a `numpy.ndarray`. For - uninitialized Tensors, returns `None`. For Tensors of data types that - cannot be converted to `numpy.ndarray` (e.g., `tf.resource`), return - `None`. + The tensor value loaded from the event file, as a `numpy.ndarray`, if + representation of the tensor value by a `numpy.ndarray` is possible. + For uninitialized Tensors, returns `None`. For Tensors of data types that + cannot be represented as `numpy.ndarray` (e.g., `tf.resource`), return + the `TensorProto` protobuf object without converting it to a + `numpy.ndarray`. """ - if (event.summary.value[0].tensor.tensor_content or - event.summary.value[0].tensor.string_val): + tensor_proto = event.summary.value[0].tensor + if tensor_proto.tensor_content or tensor_proto.string_val: # Initialized tensor. - tensor_proto = event.summary.value[0].tensor if tensor_proto.dtype == types_pb2.DT_RESOURCE: - return None + tensor_value = InconvertibleTensorProto(tensor_proto) else: try: tensor_value = tensor_util.MakeNdarray(tensor_proto) except KeyError: - tensor_value = None + tensor_value = InconvertibleTensorProto(tensor_proto) else: # Uninitialized tensor or tensor of unconvertible data type. - tensor_value = None + tensor_value = InconvertibleTensorProto(tensor_proto, False) return tensor_value @@ -290,8 +315,10 @@ def has_inf_or_nan(datum, tensor): _ = datum # Datum metadata is unused in this predicate. - if tensor is None: + if isinstance(tensor, InconvertibleTensorProto): # Uninitialized tensor doesn't have bad numerical values. + # Also return False for data types that cannot be represented as numpy + # arrays. return False elif (np.issubdtype(tensor.dtype, np.float) or np.issubdtype(tensor.dtype, np.complex) or diff --git a/tensorflow/python/debug/lib/debug_data_test.py b/tensorflow/python/debug/lib/debug_data_test.py index 2f6822331d..dc45e8df6c 100644 --- a/tensorflow/python/debug/lib/debug_data_test.py +++ b/tensorflow/python/debug/lib/debug_data_test.py @@ -23,6 +23,7 @@ import tempfile import numpy as np +from tensorflow.core.framework import tensor_pb2 from tensorflow.python.debug.lib import debug_data from tensorflow.python.framework import test_util from tensorflow.python.platform import googletest @@ -125,9 +126,15 @@ class HasNanOrInfTest(test_util.TensorFlowTestCase): a = np.array([]) self.assertFalse(debug_data.has_inf_or_nan(self._dummy_datum, a)) - def testNone(self): - a = None - self.assertFalse(debug_data.has_inf_or_nan(self._dummy_datum, a)) + def testInconvertibleTensorProto(self): + self.assertFalse(debug_data.has_inf_or_nan( + self._dummy_datum, + debug_data.InconvertibleTensorProto(tensor_pb2.TensorProto(), + initialized=False))) + self.assertFalse(debug_data.has_inf_or_nan( + self._dummy_datum, + debug_data.InconvertibleTensorProto(tensor_pb2.TensorProto(), + initialized=True))) def testDTypeComplexWorks(self): a = np.array([1j, 3j, 3j, 7j], dtype=np.complex128) diff --git a/tensorflow/python/debug/lib/session_debug_testlib.py b/tensorflow/python/debug/lib/session_debug_testlib.py index acebb2297f..deb8249343 100644 --- a/tensorflow/python/debug/lib/session_debug_testlib.py +++ b/tensorflow/python/debug/lib/session_debug_testlib.py @@ -358,9 +358,11 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase): u_vals = dump.get_tensors(u_name, 0, "DebugIdentity") s_vals = dump.get_tensors(s_name, 0, "DebugIdentity") self.assertEqual(1, len(u_vals)) - self.assertIsNone(u_vals[0]) + self.assertIsInstance(u_vals[0], debug_data.InconvertibleTensorProto) + self.assertFalse(u_vals[0].initialized) self.assertEqual(1, len(s_vals)) - self.assertIsNone(s_vals[0]) + self.assertIsInstance(s_vals[0], debug_data.InconvertibleTensorProto) + self.assertFalse(s_vals[0].initialized) # Call run() again, to check that u is initialized properly. self.assertAllClose(u_init_val, sess.run(u)) @@ -1422,7 +1424,10 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase): self._dump_root, partition_graphs=run_metadata.partition_graphs) self.assertTrue(dump.loaded_partition_graphs()) - self.assertIsNone(dump.get_tensors("fifo_queue", 0, "DebugIdentity")[0]) + fifo_queue_tensor = dump.get_tensors("fifo_queue", 0, "DebugIdentity")[0] + self.assertIsInstance(fifo_queue_tensor, + debug_data.InconvertibleTensorProto) + self.assertTrue(fifo_queue_tensor.initialized) self.assertAllClose( [101.0, 202.0, 303.0], dump.get_tensors("enqueue_many/component_0", 0, "DebugIdentity")[0]) -- GitLab From 0f30fd374f0ad00bfeca434f544cc52f6f559499 Mon Sep 17 00:00:00 2001 From: Jonathan Hseu Date: Wed, 26 Apr 2017 12:45:41 -0800 Subject: [PATCH 015/697] Release the Python GIL around all file_io operations. Change: 154338968 --- tensorflow/python/lib/io/file_io.i | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tensorflow/python/lib/io/file_io.i b/tensorflow/python/lib/io/file_io.i index c0c4e035fc..a6fe802597 100644 --- a/tensorflow/python/lib/io/file_io.i +++ b/tensorflow/python/lib/io/file_io.i @@ -31,6 +31,13 @@ limitations under the License. #include "tensorflow/core/protobuf/meta_graph.pb.h" %} +// Release the Python GIL for the duration of all methods. +%exception { + Py_BEGIN_ALLOW_THREADS; + $action + Py_END_ALLOW_THREADS; +} + %{ inline void FileExists(const string& filename, TF_Status* out_status) { tensorflow::Status status = tensorflow::Env::Default()->FileExists(filename); @@ -299,3 +306,6 @@ string ReadFromStream(tensorflow::io::BufferedInputStream* stream, %include "tensorflow/core/lib/io/path.h" %include "tensorflow/core/platform/file_statistics.h" + +// Delete the previously defined default handler that releases the Python GIL. +%noexception; -- GitLab From b82cb8e93245b0de66794f8986db453d022ae341 Mon Sep 17 00:00:00 2001 From: Rohan Jain Date: Wed, 26 Apr 2017 13:15:06 -0800 Subject: [PATCH 016/697] Moving _read_keyed_batch_{features, examples}_shared_queue from private to public. Change: 154343093 --- .../learn/python/learn/learn_io/__init__.py | 4 +- .../learn/python/learn/learn_io/graph_io.py | 48 +++++++++---------- .../python/learn/learn_io/graph_io_test.py | 7 ++- 3 files changed, 29 insertions(+), 30 deletions(-) diff --git a/tensorflow/contrib/learn/python/learn/learn_io/__init__.py b/tensorflow/contrib/learn/python/learn/learn_io/__init__.py index 4567928358..06c3782a47 100644 --- a/tensorflow/contrib/learn/python/learn/learn_io/__init__.py +++ b/tensorflow/contrib/learn/python/learn/learn_io/__init__.py @@ -21,14 +21,14 @@ from __future__ import print_function from tensorflow.contrib.learn.python.learn.learn_io.dask_io import extract_dask_data from tensorflow.contrib.learn.python.learn.learn_io.dask_io import extract_dask_labels from tensorflow.contrib.learn.python.learn.learn_io.dask_io import HAS_DASK -from tensorflow.contrib.learn.python.learn.learn_io.graph_io import _read_keyed_batch_examples_shared_queue -from tensorflow.contrib.learn.python.learn.learn_io.graph_io import _read_keyed_batch_features_shared_queue from tensorflow.contrib.learn.python.learn.learn_io.graph_io import queue_parsed_features from tensorflow.contrib.learn.python.learn.learn_io.graph_io import read_batch_examples from tensorflow.contrib.learn.python.learn.learn_io.graph_io import read_batch_features from tensorflow.contrib.learn.python.learn.learn_io.graph_io import read_batch_record_features from tensorflow.contrib.learn.python.learn.learn_io.graph_io import read_keyed_batch_examples +from tensorflow.contrib.learn.python.learn.learn_io.graph_io import read_keyed_batch_examples_shared_queue from tensorflow.contrib.learn.python.learn.learn_io.graph_io import read_keyed_batch_features +from tensorflow.contrib.learn.python.learn.learn_io.graph_io import read_keyed_batch_features_shared_queue from tensorflow.contrib.learn.python.learn.learn_io.numpy_io import numpy_input_fn from tensorflow.contrib.learn.python.learn.learn_io.pandas_io import extract_pandas_data from tensorflow.contrib.learn.python.learn.learn_io.pandas_io import extract_pandas_labels diff --git a/tensorflow/contrib/learn/python/learn/learn_io/graph_io.py b/tensorflow/contrib/learn/python/learn/learn_io/graph_io.py index 9bdd3206b2..6b552f59d0 100644 --- a/tensorflow/contrib/learn/python/learn/learn_io/graph_io.py +++ b/tensorflow/contrib/learn/python/learn/learn_io/graph_io.py @@ -174,17 +174,17 @@ def read_keyed_batch_examples(file_pattern, seed=seed) -def _read_keyed_batch_examples_shared_queue(file_pattern, - batch_size, - reader, - randomize_input=True, - num_epochs=None, - queue_capacity=10000, - num_threads=1, - read_batch_size=1, - parse_fn=None, - name=None, - seed=None): +def read_keyed_batch_examples_shared_queue(file_pattern, + batch_size, + reader, + randomize_input=True, + num_epochs=None, + queue_capacity=10000, + num_threads=1, + read_batch_size=1, + parse_fn=None, + name=None, + seed=None): """Adds operations to read, queue, batch `Example` protos. Given file pattern (or list of files), will setup a shared queue for file @@ -512,18 +512,18 @@ def read_keyed_batch_features(file_pattern, name=scope) -def _read_keyed_batch_features_shared_queue(file_pattern, - batch_size, - features, - reader, - randomize_input=True, - num_epochs=None, - queue_capacity=10000, - reader_num_threads=1, - feature_queue_capacity=100, - num_queue_runners=2, - parse_fn=None, - name=None): +def read_keyed_batch_features_shared_queue(file_pattern, + batch_size, + features, + reader, + randomize_input=True, + num_epochs=None, + queue_capacity=10000, + reader_num_threads=1, + feature_queue_capacity=100, + num_queue_runners=2, + parse_fn=None, + name=None): """Adds operations to read, queue, batch and parse `Example` protos. Given file pattern (or list of files), will setup a shared queue for file @@ -571,7 +571,7 @@ def _read_keyed_batch_features_shared_queue(file_pattern, """ with ops.name_scope(name, 'read_batch_features', [file_pattern]) as scope: - keys, examples = _read_keyed_batch_examples_shared_queue( + keys, examples = read_keyed_batch_examples_shared_queue( file_pattern, batch_size, reader, diff --git a/tensorflow/contrib/learn/python/learn/learn_io/graph_io_test.py b/tensorflow/contrib/learn/python/learn/learn_io/graph_io_test.py index 542aaabc95..f25f7caf61 100644 --- a/tensorflow/contrib/learn/python/learn/learn_io/graph_io_test.py +++ b/tensorflow/contrib/learn/python/learn/learn_io/graph_io_test.py @@ -26,7 +26,6 @@ import tempfile from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.contrib.learn.python.learn.learn_io import graph_io -from tensorflow.contrib.learn.python.learn.learn_io.graph_io import _read_keyed_batch_examples_shared_queue from tensorflow.python.client import session as session_lib from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes as dtypes_lib @@ -464,7 +463,7 @@ class GraphIOTest(test.TestCase): name = "my_batch" with ops.Graph().as_default() as g, self.test_session(graph=g) as session: - keys, inputs = _read_keyed_batch_examples_shared_queue( + keys, inputs = graph_io.read_keyed_batch_examples_shared_queue( filenames, batch_size, reader=io_ops.TextLineReader, @@ -528,7 +527,7 @@ class GraphIOTest(test.TestCase): with ops.Graph().as_default() as g1, session_lib.Session( server.target, graph=g1) as session: - keys, inputs = _read_keyed_batch_examples_shared_queue( + keys, inputs = graph_io.read_keyed_batch_examples_shared_queue( filenames, batch_size, reader=io_ops.TextLineReader, @@ -557,7 +556,7 @@ class GraphIOTest(test.TestCase): with ops.Graph().as_default() as g2, session_lib.Session( server.target, graph=g2) as session: - keys, inputs = _read_keyed_batch_examples_shared_queue( + keys, inputs = graph_io.read_keyed_batch_examples_shared_queue( filenames, batch_size, reader=io_ops.TextLineReader, -- GitLab From 0ad55c0ffdb3a2c86881e791d34fbdf1aacb359f Mon Sep 17 00:00:00 2001 From: David Majnemer Date: Wed, 26 Apr 2017 13:19:33 -0800 Subject: [PATCH 017/697] [XLA] Run transpose_folding on nested computations We only ran the pass on the entry computation which would make us lose out on optimization opportunities. Visit all computations to find any potential transpose folding opportunities. Change: 154343660 --- .../xla/service/gpu/ir_emission_utils.cc | 10 ++++ .../compiler/xla/service/transpose_folding.cc | 23 +++++---- .../xla/service/transpose_folding_test.cc | 48 +++++++++++++++++-- 3 files changed, 66 insertions(+), 15 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc index e8378a7f44..c6e8a2f78b 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc @@ -59,6 +59,11 @@ bool AreValidGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape, } // namespace bool ImplementedAsGemm(const HloInstruction& hlo) { + // We can only do this if the HLO is unnested. + if (hlo.parent() != hlo.GetModule()->entry_computation()) { + return false; + } + // For certain types of Dot, we can call pre-canned BLAS gemm. if (hlo.opcode() == HloOpcode::kDot) { const Shape& lhs_shape = hlo.operand(0)->shape(); @@ -85,6 +90,11 @@ bool ImplementedAsGemm(const HloInstruction& hlo) { } bool ImplementedAsDnnConvolution(const HloInstruction& hlo) { + // We can only do this if the HLO is unnested. + if (hlo.parent() != hlo.GetModule()->entry_computation()) { + return false; + } + // Forward convolution. if (hlo.opcode() == HloOpcode::kConvolution) { const ConvolutionDimensionNumbers& dnums = diff --git a/tensorflow/compiler/xla/service/transpose_folding.cc b/tensorflow/compiler/xla/service/transpose_folding.cc index cfb90e6e1d..a0c88c6bbc 100644 --- a/tensorflow/compiler/xla/service/transpose_folding.cc +++ b/tensorflow/compiler/xla/service/transpose_folding.cc @@ -76,8 +76,7 @@ using InstructionOperandsPair = // the parent HLO computation of `dot`. // // Returns whether the module is changed. -bool FoldTransposeIntoDot(InstructionOperandsPair pair, - HloComputation* computation) { +bool FoldTransposeIntoDot(InstructionOperandsPair pair) { auto* dot = pair.first; std::vector instructions_to_fuse(1, dot); for (const int64 operand_index : pair.second) { @@ -89,7 +88,7 @@ bool FoldTransposeIntoDot(InstructionOperandsPair pair, return false; } - computation->CreateFusionInstruction( + dot->parent()->CreateFusionInstruction( instructions_to_fuse, HloInstruction::FusionKind::kTransposeDot); return true; } @@ -98,8 +97,7 @@ bool FoldTransposeIntoDot(InstructionOperandsPair pair, // `computation` is the parent HLO computation of `convolution`. // // Returns whether the module is changed. -bool FoldTransposeIntoConvolution(InstructionOperandsPair pair, - HloComputation* computation) { +bool FoldTransposeIntoConvolution(InstructionOperandsPair pair) { auto& convolution = *pair.first; // We only support fusing the RHS transpose into convolution. @@ -135,8 +133,8 @@ bool FoldTransposeIntoConvolution(InstructionOperandsPair pair, auto new_conv = HloInstruction::CreateConvolve( convolution.shape(), convolution.mutable_operand(0), &transpose_operand, convolution.window(), new_dnums); - TF_CHECK_OK(computation->ReplaceWithNewInstruction(&convolution, - std::move(new_conv))); + TF_CHECK_OK(convolution.parent()->ReplaceWithNewInstruction( + &convolution, std::move(new_conv))); return true; } @@ -152,8 +150,6 @@ TransposeFolding::TransposeFolding( StatusOr TransposeFolding::Run(HloModule* module) { // Modifying the graph while traversing is dangerous, so we find all folding // opportunities before actually folding them. - HloComputation* entry_computation = module->entry_computation(); - std::vector> foldable_dots; std::vector> foldable_convolutions; auto visit_fn = [this, &foldable_dots, @@ -175,14 +171,17 @@ StatusOr TransposeFolding::Run(HloModule* module) { } return tensorflow::Status::OK(); }; - TF_RETURN_IF_ERROR(entry_computation->root_instruction()->Accept(visit_fn)); + + for (auto& comp : module->computations()) { + TF_RETURN_IF_ERROR(comp->Accept(visit_fn)); + } bool changed = false; for (InstructionOperandsPair& pair : foldable_dots) { - changed |= FoldTransposeIntoDot(pair, entry_computation); + changed |= FoldTransposeIntoDot(pair); } for (InstructionOperandsPair& pair : foldable_convolutions) { - changed |= FoldTransposeIntoConvolution(pair, entry_computation); + changed |= FoldTransposeIntoConvolution(pair); } return changed; } diff --git a/tensorflow/compiler/xla/service/transpose_folding_test.cc b/tensorflow/compiler/xla/service/transpose_folding_test.cc index 6643f541da..c72d127ea8 100644 --- a/tensorflow/compiler/xla/service/transpose_folding_test.cc +++ b/tensorflow/compiler/xla/service/transpose_folding_test.cc @@ -41,9 +41,7 @@ class TransposeFoldingTest : public ::testing::Test { TransposeFolding transpose_folding( [](const HloInstruction& dot, const TransposeFolding::OperandIndices& candidate_operands) { - return gpu::ImplementedAsGemm(dot) - ? candidate_operands - : TransposeFolding::OperandIndices{}; + return candidate_operands; }, [](const HloInstruction& convolution, const TransposeFolding::OperandIndices& candidate_operands) { @@ -159,6 +157,50 @@ TEST_F(TransposeFoldingTest, FuseDotWithConstantOperands) { EXPECT_EQ(6, callee_computation->instructions().size()); } +TEST_F(TransposeFoldingTest, FoldDotTransposeInWhile) { + auto builder = HloComputation::Builder("entry_computation"); + HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {2, 3}), + /*name=*/"x")); + HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {2, 3}), + /*name=*/"y")); + HloInstruction* transpose_y = + builder.AddInstruction(HloInstruction::CreateTranspose( + ShapeUtil::MakeShape(F32, {3, 2}), y, {1, 0})); + HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(F32, {2, 2}), /*opcode=*/HloOpcode::kDot, + /*lhs=*/x, /*rhs=*/transpose_y)); + + HloModule module("test_module"); + HloComputation* entry_computation = + module.AddEntryComputation(builder.Build(dot)); + + HloInstruction* call = module.OutlineExpressionFromComputation( + {transpose_y, dot}, "outlined", entry_computation); + + FoldTranspose(&module); + + // Instructions after folding: x, y, and the fusion. + std::unordered_set instruction_set; + for (auto& instruction : entry_computation->instructions()) { + instruction_set.insert(instruction.get()); + } + CHECK_EQ(1, instruction_set.erase(x)) << "x is not in entry_computation."; + CHECK_EQ(1, instruction_set.erase(y)) << "y is not in entry_computation."; + CHECK_EQ(1, instruction_set.erase(call)) + << "call is not in entry_computation."; + CHECK(instruction_set.empty()) + << "entry_computation should contain exactly 3 instructions."; + HloInstruction* fusion = + call->called_computations().front()->root_instruction(); + EXPECT_EQ(HloOpcode::kFusion, fusion->opcode()); + + // The fusion instruction should contain two parameters, one transpose and + // one dot. + EXPECT_EQ(4, fusion->fused_instructions().size()); +} + // Test that a two dimension swap of the kernel gets folded into convolution. TEST_F(TransposeFoldingTest, FoldConvDimSwapTransposeRhs) { auto builder = HloComputation::Builder("entry_computation"); -- GitLab From 028e19ace26b676e7d136b98639296a1160110ff Mon Sep 17 00:00:00 2001 From: Yao Zhang Date: Wed, 26 Apr 2017 13:33:29 -0800 Subject: [PATCH 018/697] Add an option for shape inference requirement to C++ Shape Refiner. Change: 154345494 --- .../core/common_runtime/shape_refiner.cc | 16 +++++++++++--- .../core/common_runtime/shape_refiner.h | 6 ++++++ .../core/common_runtime/shape_refiner_test.cc | 21 +++++++++++++++++++ 3 files changed, 40 insertions(+), 3 deletions(-) diff --git a/tensorflow/core/common_runtime/shape_refiner.cc b/tensorflow/core/common_runtime/shape_refiner.cc index f2dff0bf75..5135355a94 100644 --- a/tensorflow/core/common_runtime/shape_refiner.cc +++ b/tensorflow/core/common_runtime/shape_refiner.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/stl_util.h" @@ -79,7 +80,8 @@ Status ShapeRefiner::AddNode(const Node* node) { // Get the shape function for this node const OpRegistrationData* op_reg_data; TF_RETURN_IF_ERROR(ops_registry_->LookUp(node->type_string(), &op_reg_data)); - if (op_reg_data->shape_inference_fn == nullptr) { + if (op_reg_data->shape_inference_fn == nullptr && + require_shape_inference_fns_) { return errors::InvalidArgument( "No shape inference function exists for op '", node->type_string(), "', did you forget to define it?"); @@ -102,7 +104,11 @@ Status ShapeRefiner::AddNode(const Node* node) { } // Run the shape inference function, and return if there was an error. - TF_RETURN_IF_ERROR(c->Run(op_reg_data->shape_inference_fn)); + if (op_reg_data->shape_inference_fn) { + TF_RETURN_IF_ERROR(c->Run(op_reg_data->shape_inference_fn)); + } else { + TF_RETURN_IF_ERROR(c->Run(shape_inference::UnknownShape)); + } // We must run the shape function repeatedly, in case users write // shape functions where they only conditionally call input_tensor() @@ -163,7 +169,11 @@ Status ShapeRefiner::AddNode(const Node* node) { // so re-run shape inference. c->set_input_tensors(input_tensors); c->set_input_tensors_as_shapes(input_tensors_as_shapes); - TF_RETURN_IF_ERROR(op_reg_data->shape_inference_fn(c.get())); + if (op_reg_data->shape_inference_fn) { + TF_RETURN_IF_ERROR(op_reg_data->shape_inference_fn(c.get())); + } else { + TF_RETURN_IF_ERROR(shape_inference::UnknownShape(c.get())); + } } } while (rerun_shape_fn); diff --git a/tensorflow/core/common_runtime/shape_refiner.h b/tensorflow/core/common_runtime/shape_refiner.h index f23f9361eb..2d04ea1505 100644 --- a/tensorflow/core/common_runtime/shape_refiner.h +++ b/tensorflow/core/common_runtime/shape_refiner.h @@ -68,6 +68,10 @@ class ShapeRefiner { int32 graph_def_version() { return graph_def_version_; } void set_graph_def_version(int32 version) { graph_def_version_ = version; } + void set_require_shape_inference_fns(bool require_shape_inference_fns) { + require_shape_inference_fns_ = require_shape_inference_fns; + } + private: // Extracts the subgraph ending at 'node' that is statically // computable and inserts into 'out_graph'. If statically computable, @@ -129,6 +133,8 @@ class ShapeRefiner { static constexpr int64 kMaxTensorSize = 1024; std::unordered_map const_tensor_map_; + bool require_shape_inference_fns_ = true; + TF_DISALLOW_COPY_AND_ASSIGN(ShapeRefiner); }; diff --git a/tensorflow/core/common_runtime/shape_refiner_test.cc b/tensorflow/core/common_runtime/shape_refiner_test.cc index 05274ff311..d7e7c3b5ad 100644 --- a/tensorflow/core/common_runtime/shape_refiner_test.cc +++ b/tensorflow/core/common_runtime/shape_refiner_test.cc @@ -126,6 +126,27 @@ TEST(ShapeRefinerTest, SetShape) { ASSERT_FALSE(m.SetShape(a.node(), 0, h).ok()); } +namespace { + +// An op with no shape function. +REGISTER_OP("TestOpWithNoShapeFn").Input("a: int32").Output("o: int32"); + +} // namespace + +TEST(ShapeRefinerTest, MissingShapeInferenceFns) { + Scope root = Scope::NewRootScope(); + auto a = ops::Const(root, 42); + Node* b; + TF_ASSERT_OK(NodeBuilder("b", "TestOpWithNoShapeFn") + .Input(a.node()) + .Finalize(root.graph(), &b)); + ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global()); + TF_ASSERT_OK(m.AddNode(a.node())); + EXPECT_FALSE(m.AddNode(b).ok()); + m.set_require_shape_inference_fns(false); + TF_EXPECT_OK(m.AddNode(b)); +} + TEST(ShapeRefinerTest, PropagateConstants) { // Reduction dimension is a variable, so we don't know its value. // So the output shape value is unknown (though its rank is known). -- GitLab From 5c2c39d19e64634dc9a3d33fec183cf324289355 Mon Sep 17 00:00:00 2001 From: James Qin Date: Wed, 26 Apr 2017 13:46:31 -0800 Subject: [PATCH 019/697] Fix two tiny bugs in cudnn_rnn_ops 1. test cudnn_rnn_ops_test.py:testSaveRestore was broken. 2. a tiny bug that hardcoded is_training in GRU model. Change: 154347435 --- .../python/kernel_tests/cudnn_rnn_ops_test.py | 5 ++--- .../contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py | 10 +++++----- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py index 945791578a..b6047c531c 100644 --- a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py +++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py @@ -58,9 +58,8 @@ class CudnnRNNTest(TensorFlowTestCase): params: a Variable for weight and bias parameters. model: a CudnnRNN model. """ - params_saveable = cudnn_rnn_ops.RNNParamsSaveable(model.params_to_canonical, - model.canonical_to_params, - params) + params_saveable = cudnn_rnn_ops.RNNParamsSaveable( + model.params_to_canonical, model.canonical_to_params, [params]) ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, params_saveable) def _testSaveRestoreVariable(self, rnn_mode): diff --git a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py index 9ab337df15..c23d4cd4e3 100644 --- a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py +++ b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py @@ -48,8 +48,8 @@ class RNNParamsSaveable(saver.BaseSaverBuilder.SaveableObject): def __init__(self, params_to_canonical, canonical_to_params, - name="params_canonical", - *param_variables): + param_variables, + name="params_canonical"): """Creates a RNNParamsSaveable object. RNNParamsSaveable is saveable/restorable in a checkpoint file and is used @@ -83,11 +83,11 @@ class RNNParamsSaveable(saver.BaseSaverBuilder.SaveableObject): must return a scalar (e.g. in the case of cuDNN) or a tuple. This function could be _CudnnRNN.canonical_to_params() or a user-defined function. - name: the name of the RNNParamsSaveable object. - *param_variables: a list of Variables for parameters in a specific form. + param_variables: a list of Variables for parameters in a specific form. For cuDNN RNN ops, this is a single merged variable for both weights and biases; for other RNN ops, this might be multiple unmerged or partially merged variables respectively for weights and biases. + name: the name of the RNNParamsSaveable object. """ # There is only a single merged parameter variable for cuDNN when saving. weights, biases = params_to_canonical(param_variables[0]) @@ -411,7 +411,7 @@ class _CudnnRNNNoInputC(_CudnnRNN): output_h: the final state for h. """ output, output_h, _ = super(_CudnnRNNNoInputC, self).__call__( - input_data, input_h, None, params, is_training=True) + input_data, input_h, None, params, is_training=is_training) return (output, output_h) -- GitLab From 341141a48bea3e39c76ed36e9015bd9ab67a6463 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 26 Apr 2017 13:56:44 -0800 Subject: [PATCH 020/697] [XLA] Eliminate empty operands of concat operations at the HLO level Change: 154348912 --- .../xla/service/algebraic_simplifier.cc | 23 +++++++ .../xla/service/algebraic_simplifier_test.cc | 62 +++++++++++++++++++ tensorflow/compiler/xla/tests/concat_test.cc | 9 +++ 3 files changed, 94 insertions(+) diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 6acb9bdcba..ee265c6688 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -307,6 +307,29 @@ Status AlgebraicSimplifierVisitor::HandleConcatenate( // Unary concatenates are useless. if (operands.size() == 1) { ReplaceInstructionIfSameShape(concatenate, operands[0]); + return Status::OK(); + } + // Filter out and remove empty operands. + std::vector nonempty_operands; + for (HloInstruction* operand : operands) { + if (!ShapeUtil::HasZeroElements(operand->shape())) { + nonempty_operands.push_back(operand); + } + } + if (nonempty_operands.size() < operands.size()) { + HloInstruction* replacement; + if (nonempty_operands.empty()) { + replacement = operands[0]; + } else if (nonempty_operands.size() == 1) { + replacement = nonempty_operands[0]; + } else { + replacement = + computation_->AddInstruction(concatenate->CloneWithNewOperands( + concatenate->shape(), nonempty_operands)); + } + VLOG(10) << "trying to replace " << concatenate->ToString() << " with " + << replacement->ToString(); + ReplaceInstructionIfSameShape(concatenate, replacement); } return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index 3123ee4f87..77b8fca1a9 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -531,6 +531,68 @@ TEST_F(AlgebraicSimplifierTest, RemoveUnaryConcatenate) { EXPECT_THAT(computation->root_instruction(), param0); } +// Test that empty operands of concatenates are removed. +TEST_F(AlgebraicSimplifierTest, RemoveEmptyConcatenateOperands) { + const int kParamLength = 100; + Shape r1f32 = ShapeUtil::MakeShape(F32, {kParamLength}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r1f32, "param0")); + HloInstruction* param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, r1f32, "param1")); + HloInstruction* empty_literal = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR1({}))); + HloInstruction* empty_slice = + builder.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {0}), param1, {42}, {42})); + Shape result_shape = ShapeUtil::MakeShape(F32, {3 * kParamLength}); + builder.AddInstruction(HloInstruction::CreateConcatenate( + result_shape, {empty_literal, param0, param0, empty_slice, param1}, 0)); + + auto module = MakeUnique(TestName()); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_THAT( + computation->root_instruction(), + op::Concatenate(empty_literal, param0, param0, empty_slice, param1)); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + + EXPECT_THAT(computation->root_instruction(), + op::Concatenate(param0, param0, param1)); +} + +// Test a concatenate with only empty operands is removed. +TEST_F(AlgebraicSimplifierTest, OnlyEmptyConcatenateOperands) { + const int kParamLength = 100; + Shape r1f32 = ShapeUtil::MakeShape(F32, {kParamLength}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r1f32, "param0")); + HloInstruction* empty_literal = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR1({}))); + HloInstruction* empty_slice = + builder.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {0}), param0, {42}, {42})); + Shape result_shape = ShapeUtil::MakeShape(F32, {0}); + builder.AddInstruction(HloInstruction::CreateConcatenate( + result_shape, {empty_literal, empty_slice}, 0)); + + auto module = MakeUnique(TestName()); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_THAT(computation->root_instruction(), + op::Concatenate(empty_literal, empty_slice)); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + + EXPECT_EQ(computation->root_instruction(), empty_literal); +} + // Test that a simplification which changes layouts is not performed if layout // sensitive is true. TEST_F(AlgebraicSimplifierTest, CopyWithDifferentLayout) { diff --git a/tensorflow/compiler/xla/tests/concat_test.cc b/tensorflow/compiler/xla/tests/concat_test.cc index e645e23361..63bfac441d 100644 --- a/tensorflow/compiler/xla/tests/concat_test.cc +++ b/tensorflow/compiler/xla/tests/concat_test.cc @@ -57,6 +57,15 @@ XLA_TEST_F(ConcatTest, Concat_R1_With_Nothing) { ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); } +XLA_TEST_F(ConcatTest, Concat_R1_L0_With_Nothing) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({}); + auto concatenated = builder.ConcatInDim({a}, 0); + + std::vector expected = {}; + ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); +} + // Show that we can't concatenate R0 with R0 because we can't name the dimension // to concatenate on. XLA_TEST_F(ConcatTest, CannotConcatR0WithR0) { -- GitLab From a62d489db24365435ae046dbbb4e6616c16d283e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dandelion=20Man=C3=A9?= Date: Wed, 26 Apr 2017 14:12:03 -0800 Subject: [PATCH 021/697] Migrate tf_dashboard_common to d3v4. (No externally visible change, yet.) Change: 154351373 --- .../tf_color_scale_d3v4/colorScale.ts | 2 - .../tf_color_scale_d3v4/tf-color-scale.html | 3 - .../dashboard-behavior.ts | 38 ++ .../dashboard-style.html | 53 +++ .../reload-behavior.ts | 39 ++ .../run-color-style.html | 79 ++++ .../scrollbar-style.html | 46 +++ .../tensorboard-color.html | 32 ++ .../tf_dashboard_common_d3v4/tests.html | 31 ++ .../tf-categorizer-demo.html | 108 ++++++ .../tf-categorizer-tests.ts | 143 +++++++ .../tf-categorizer.html | 59 +++ .../tf-categorizer.ts | 183 +++++++++ .../tf-chart-scaffold.html | 150 ++++++++ .../tf-collapsable-pane-demo.html | 34 ++ .../tf-collapsable-pane.html | 109 ++++++ .../tf-dashboard-layout.html | 67 ++++ .../tf-dashboard.html | 24 ++ .../tf-downloader.html | 99 +++++ .../tf-multi-checkbox-demo.html | 177 +++++++++ .../tf-multi-checkbox.html | 157 ++++++++ .../tf-multi-checkbox.ts | 206 ++++++++++ .../tf-no-data-warning.html | 129 +++++++ .../tf-option-selector.html | 94 +++++ .../tf-panes-helper.html | 352 ++++++++++++++++++ .../tf-regex-group-demo.html | 46 +++ .../tf-regex-group.html | 97 +++++ .../tf-regex-group.ts | 84 +++++ .../tf-run-selector.html | 188 ++++++++++ .../tf-sidebar-helper.html | 164 ++++++++ 30 files changed, 2988 insertions(+), 5 deletions(-) create mode 100644 tensorflow/tensorboard/components/tf_dashboard_common_d3v4/dashboard-behavior.ts create mode 100644 tensorflow/tensorboard/components/tf_dashboard_common_d3v4/dashboard-style.html create mode 100644 tensorflow/tensorboard/components/tf_dashboard_common_d3v4/reload-behavior.ts create mode 100644 tensorflow/tensorboard/components/tf_dashboard_common_d3v4/run-color-style.html create mode 100644 tensorflow/tensorboard/components/tf_dashboard_common_d3v4/scrollbar-style.html create mode 100644 tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tensorboard-color.html create mode 100644 tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tests.html create mode 100644 tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tf-categorizer-demo.html create mode 100644 tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tf-categorizer-tests.ts create mode 100644 tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tf-categorizer.html create mode 100644 tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tf-categorizer.ts create mode 100644 tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tf-chart-scaffold.html create mode 100644 tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tf-collapsable-pane-demo.html create mode 100644 tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tf-collapsable-pane.html create mode 100644 tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tf-dashboard-layout.html create mode 100644 tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tf-dashboard.html create mode 100644 tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tf-downloader.html create mode 100644 tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tf-multi-checkbox-demo.html create mode 100644 tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tf-multi-checkbox.html create mode 100644 tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tf-multi-checkbox.ts create mode 100644 tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tf-no-data-warning.html create mode 100644 tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tf-option-selector.html create mode 100644 tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tf-panes-helper.html create mode 100644 tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tf-regex-group-demo.html create mode 100644 tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tf-regex-group.html create mode 100644 tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tf-regex-group.ts create mode 100644 tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tf-run-selector.html create mode 100644 tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tf-sidebar-helper.html diff --git a/tensorflow/tensorboard/components/tf_color_scale_d3v4/colorScale.ts b/tensorflow/tensorboard/components/tf_color_scale_d3v4/colorScale.ts index d10e753445..ff90d46aa2 100644 --- a/tensorflow/tensorboard/components/tf_color_scale_d3v4/colorScale.ts +++ b/tensorflow/tensorboard/components/tf_color_scale_d3v4/colorScale.ts @@ -63,8 +63,6 @@ export class ColorScale { } } - - Polymer({ is: 'tf-color-scale', properties: { diff --git a/tensorflow/tensorboard/components/tf_color_scale_d3v4/tf-color-scale.html b/tensorflow/tensorboard/components/tf_color_scale_d3v4/tf-color-scale.html index 7d2cb8bafd..e3ef6cf630 100644 --- a/tensorflow/tensorboard/components/tf_color_scale_d3v4/tf-color-scale.html +++ b/tensorflow/tensorboard/components/tf_color_scale_d3v4/tf-color-scale.html @@ -25,7 +25,4 @@ a set of colors. @element tf-color-scale --> - diff --git a/tensorflow/tensorboard/components/tf_dashboard_common_d3v4/dashboard-behavior.ts b/tensorflow/tensorboard/components/tf_dashboard_common_d3v4/dashboard-behavior.ts new file mode 100644 index 0000000000..3e40da1452 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_dashboard_common_d3v4/dashboard-behavior.ts @@ -0,0 +1,38 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +/** + * A behavior that TensorBoard dashboards must implement. This behavior serves + * the purpose of an interface. + */ +export function DashboardBehavior(dashboardName) { + return { + properties: { + name: { + type: String, + value: dashboardName, + readOnly: true, + }, + }, + // This method is called when the dashboard reloads, either when the + // dashboard is first visited, periodically reloaded, or manually reloaded + // via the user clicking the button. Note that dashboard custom elements + // that use TF.Dashboard.ReloadBehavior already implement a reload method. + reload() { + throw Error( + 'The ' + dashboardName + ' dashboard does not implement reload.'); + }, + }; +} diff --git a/tensorflow/tensorboard/components/tf_dashboard_common_d3v4/dashboard-style.html b/tensorflow/tensorboard/components/tf_dashboard_common_d3v4/dashboard-style.html new file mode 100644 index 0000000000..1eccc020eb --- /dev/null +++ b/tensorflow/tensorboard/components/tf_dashboard_common_d3v4/dashboard-style.html @@ -0,0 +1,53 @@ + + + + + + + + diff --git a/tensorflow/tensorboard/components/tf_dashboard_common_d3v4/reload-behavior.ts b/tensorflow/tensorboard/components/tf_dashboard_common_d3v4/reload-behavior.ts new file mode 100644 index 0000000000..8b5ca120d6 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_dashboard_common_d3v4/reload-behavior.ts @@ -0,0 +1,39 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +/** + * ReloadBehavior: A simple behavior for dashboards where the + * frontendReload() function should find every child element with a + * given tag name (e.g. "tf-line-chart" or "tf-image-loader") + * and call a `reload` method on that child. + * May later extend it so it has more sophisticated logic, e.g. reloading + * only tags that are in view. + */ +export function ReloadBehavior(tagName) { + return { + properties: { + reloadTag: { + type: String, + value: tagName, + }, + }, + frontendReload: function() { + var elements = this.getElementsByTagName(this.reloadTag); + Array.prototype.forEach.call(elements, function(x) { + x.reload(); + }); + }, + }; +} diff --git a/tensorflow/tensorboard/components/tf_dashboard_common_d3v4/run-color-style.html b/tensorflow/tensorboard/components/tf_dashboard_common_d3v4/run-color-style.html new file mode 100644 index 0000000000..ff4cfacc91 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_dashboard_common_d3v4/run-color-style.html @@ -0,0 +1,79 @@ + + + + + + + diff --git a/tensorflow/tensorboard/components/tf_dashboard_common_d3v4/scrollbar-style.html b/tensorflow/tensorboard/components/tf_dashboard_common_d3v4/scrollbar-style.html new file mode 100644 index 0000000000..b345781e2e --- /dev/null +++ b/tensorflow/tensorboard/components/tf_dashboard_common_d3v4/scrollbar-style.html @@ -0,0 +1,46 @@ + + + + + + + + diff --git a/tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tensorboard-color.html b/tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tensorboard-color.html new file mode 100644 index 0000000000..4d95351edb --- /dev/null +++ b/tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tensorboard-color.html @@ -0,0 +1,32 @@ + + + + + diff --git a/tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tests.html b/tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tests.html new file mode 100644 index 0000000000..270a2522b9 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tests.html @@ -0,0 +1,31 @@ + + + + + + + + + + + + + + + + diff --git a/tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tf-categorizer-demo.html b/tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tf-categorizer-demo.html new file mode 100644 index 0000000000..6692962763 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tf-categorizer-demo.html @@ -0,0 +1,108 @@ + + + + + + + + + + + + + + + + + + + + diff --git a/tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tf-categorizer-tests.ts b/tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tf-categorizer-tests.ts new file mode 100644 index 0000000000..bf70d9d4bb --- /dev/null +++ b/tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tf-categorizer-tests.ts @@ -0,0 +1,143 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the 'License'); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an 'AS IS' BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +import {topLevelNamespaceCategorizer, Category, Categorizer, defineCategory, _categorizer} from './tf-categorizer' + +let assert = chai.assert; + +describe('categorizer', () => { + describe('topLevelNamespaceCategorizer', () => { + it('returns empty array on empty tags', () => { + assert.lengthOf(topLevelNamespaceCategorizer([]), 0); + }); + + it('handles a simple case', () => { + let simple = [ + 'foo1/bar', 'foo1/zod', 'foo2/bar', 'foo2/zod', 'gosh/lod/mar', + 'gosh/lod/ned' + ]; + let expected = [ + {name: 'foo1', tags: ['foo1/bar', 'foo1/zod']}, + {name: 'foo2', tags: ['foo2/bar', 'foo2/zod']}, + {name: 'gosh', tags: ['gosh/lod/mar', 'gosh/lod/ned']}, + ]; + assert.deepEqual(topLevelNamespaceCategorizer(simple), expected); + }); + + it('orders the categories', () => { + let test = ['e', 'f', 'g', 'a', 'b', 'c']; + let expected = [ + {name: 'a', tags: ['a']}, + {name: 'b', tags: ['b']}, + {name: 'c', tags: ['c']}, + {name: 'e', tags: ['e']}, + {name: 'f', tags: ['f']}, + {name: 'g', tags: ['g']}, + ]; + assert.deepEqual(topLevelNamespaceCategorizer(test), expected); + }); + + it('handles cases where category names overlap node names', () => { + let test = ['a', 'a/a', 'a/b', 'a/c', 'b', 'b/a']; + let actual = topLevelNamespaceCategorizer(test); + let expected = [ + {name: 'a', tags: ['a', 'a/a', 'a/b', 'a/c']}, + {name: 'b', tags: ['b', 'b/a']}, + ]; + assert.deepEqual(actual, expected); + }); + + it('handles singleton case', () => { + assert.deepEqual( + topLevelNamespaceCategorizer(['a']), [{name: 'a', tags: ['a']}]); + }); + }); + + describe('customCategorizer', () => { + function noFallbackCategorizer(tags: string[]): Category[] { + return []; + } + + function testCategorizer( + defs: string[], fallback: Categorizer, tags: string[]): Category[] { + let catDefs = defs.map(defineCategory); + return _categorizer(catDefs, fallback)(tags); + } + + it('categorizes by regular expression', () => { + let defs = ['foo..', 'bar..']; + let tags = ['fooab', 'fooxa', 'barts', 'barms']; + let actual = testCategorizer(defs, noFallbackCategorizer, tags); + let expected = [ + {name: 'foo..', tags: ['fooab', 'fooxa']}, + {name: 'bar..', tags: ['barms', 'barts']}, + ]; + assert.deepEqual(actual, expected); + }); + + it('matches non-exclusively', () => { + let tags = ['abc', 'bar', 'zod']; + let actual = + testCategorizer(['...', 'bar'], noFallbackCategorizer, tags); + let expected = [ + {name: '...', tags: ['abc', 'bar', 'zod']}, + {name: 'bar', tags: ['bar']}, + ]; + assert.deepEqual(actual, expected); + }); + + it('creates categories for unmatched rules', () => { + let actual = + testCategorizer(['a', 'b', 'c'], noFallbackCategorizer, []); + let expected = [ + {name: 'a', tags: []}, + {name: 'b', tags: []}, + {name: 'c', tags: []}, + ]; + assert.deepEqual(actual, expected); + }); + + it('category regexs work with special characters', () => { + let defs = ['^\\w+$', '^\\d+$', '^\\/..$']; + let tags = ['foo', '3243', '/xa']; + let actual = testCategorizer(defs, noFallbackCategorizer, tags); + let expected = [ + {name: '^\\w+$', tags: ['3243', 'foo']}, + {name: '^\\d+$', tags: ['3243']}, + {name: '^\\/..$', tags: ['/xa']}, + ]; + assert.deepEqual(actual, expected); + }); + + it('category tags are sorted', () => { + let tags = ['a', 'z', 'c', 'd', 'e', 'x', 'f', 'y', 'g']; + let sorted = tags.slice().sort(); + let expected = [{name: '.*', tags: sorted}]; + let actual = testCategorizer(['.*'], noFallbackCategorizer, tags); + assert.deepEqual(actual, expected); + }); + + it('if nonexclusive: all tags passed to fallback', () => { + let passedToDefault = null; + function defaultCategorizer(tags: string[]): Category[] { + passedToDefault = tags; + return []; + } + let tags = ['foo', 'bar', 'foo123']; + testCategorizer(['foo'], defaultCategorizer, tags); + assert.deepEqual(passedToDefault, tags); + }); + }); +}); diff --git a/tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tf-categorizer.html b/tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tf-categorizer.html new file mode 100644 index 0000000000..0b563fa4bd --- /dev/null +++ b/tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tf-categorizer.html @@ -0,0 +1,59 @@ + + + + + + + + + + + diff --git a/tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tf-categorizer.ts b/tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tf-categorizer.ts new file mode 100644 index 0000000000..10271d2c08 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tf-categorizer.ts @@ -0,0 +1,183 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the 'License'); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an 'AS IS' BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +import * as d3 from 'd3'; // from //third_party/javascript/typings/d3_v4 +import {compareTagNames} from '../vz_sorting_d3v4/sorting' + +/** + * This module contains methods that allow sorting tags into 'categories'. + * A category contains a name and a list of tags. + * The sorting strategy is defined by a 'CustomCategorization', which contains + * 'categoryDefinitions' which are regex rules used to construct a category. + * E.g. the regex rule 'xent' will create a category called 'xent' that + * contains values whose tags match the regex. + * + * After custom categories are evaluated, the tags are sorted by a hardcoded + * fallback categorizer, which may, for example, group tags into categories + * based on their top namespace. + */ + +export interface Category { + // Categories that data is sorted into + name: string; + tags: string[]; +} + +export interface CustomCategorization { + // Defines a categorization strategy + categoryDefinitions: string[]; + fallbackCategorizer: string; + /* {'TopLevelNamespaceCategorizer', + 'LegacyUnderscoreCategorizer'} */ +} + +export interface Categorizer { + // Function that generates categories + (tags: string[]): Category[]; +} + +/* Canonical TensorFlow ops are namespaced using forward slashes. + * This fallback categorizer categorizes by the top-level namespace. + */ +export var topLevelNamespaceCategorizer: Categorizer = splitCategorizer(/\//); + +export function fallbackCategorizer(s: string): Categorizer { + switch (s) { + case 'TopLevelNamespaceCategorizer': + return topLevelNamespaceCategorizer; + default: + throw new Error('Unrecognized categorization strategy: ' + s); + } +} + +/* An 'extractor' is a function that takes a tag name, and 'extracts' a + * category name. + * This function takes an extractor, and produces a categorizer. + * Currently, it is just used for the fallbackCategorizer, but we may want to + * refactor the general categorization logic to use the concept of extractors. + */ +function extractorToCategorizer(extractor: (s: string) => string): Categorizer { + return (tags: string[]): Category[] => { + if (tags.length === 0) { + return []; + } + let sortedTags = tags.slice().sort(compareTagNames); + let categories: Category[] = []; + let currentCategory = { + name: extractor(sortedTags[0]), + tags: [], + }; + sortedTags.forEach((t: string) => { + let topLevel = extractor(t); + if (currentCategory.name !== topLevel) { + categories.push(currentCategory); + currentCategory = { + name: topLevel, + tags: [], + }; + } + currentCategory.tags.push(t); + }); + categories.push(currentCategory); + return categories; + }; +} + +function splitCategorizer(r: RegExp): Categorizer { + let extractor = (t: string) => { + return t.split(r)[0]; + }; + return extractorToCategorizer(extractor); +} + +export interface CategoryDefinition { + name: string; + matches: (t: string) => boolean; +} + +export function defineCategory(ruledef: string): CategoryDefinition { + let r = new RegExp(ruledef); + let f = function(tag: string): boolean { + return r.test(tag); + }; + return {name: ruledef, matches: f}; +} + +export function _categorizer( + rules: CategoryDefinition[], fallback: Categorizer) { + return function(tags: string[]): Category[] { + let remaining: d3.Set = d3.set(tags); + let userSpecified = rules.map((def: CategoryDefinition) => { + let tags: string[] = []; + remaining.each((t: string) => { + if (def.matches(t)) { + tags.push(t); + } + }); + let cat = {name: def.name, tags: tags.sort(compareTagNames)}; + return cat; + }); + let defaultCategories = fallback(remaining.values()); + return userSpecified.concat(defaultCategories); + }; +} + +export function categorizer(s: CustomCategorization): Categorizer { + let rules = s.categoryDefinitions.map(defineCategory); + let fallback = fallbackCategorizer(s.fallbackCategorizer); + return _categorizer(rules, fallback); +}; + +Polymer({ + is: 'tf-categorizer', + properties: { + regexes: {type: Array}, + tags: {type: Array}, + categoriesAreExclusive: {type: Boolean, value: true}, + fallbackCategorizer: { + type: String, + value: 'TopLevelNamespaceCategorizer', + }, + categorizer: { + type: Object, + computed: + 'computeCategorization(regexes.*, categoriesAreExclusive, fallbackCategorizer)', + }, + categories: { + type: Array, + value: function() { + return []; + }, + notify: true, + readOnly: true + }, + }, + observers: ['recategorize(tags.*, categorizer)'], + computeCategorization: function( + regexes, categoriesAreExclusive, fallbackCategorizer) { + var categorizationStrategy = { + categoryDefinitions: regexes.base, + categoriesAreExclusive: categoriesAreExclusive, + fallbackCategorizer: fallbackCategorizer, + }; + return categorizer(categorizationStrategy); + }, + recategorize: function() { + this.debounce('tf-categorizer-recategorize', function() { + var categories = this.categorizer(this.tags); + this._setCategories(categories); + }) + }, +}); \ No newline at end of file diff --git a/tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tf-chart-scaffold.html b/tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tf-chart-scaffold.html new file mode 100644 index 0000000000..402f909287 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tf-chart-scaffold.html @@ -0,0 +1,150 @@ + + + + + + + + + diff --git a/tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tf-collapsable-pane-demo.html b/tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tf-collapsable-pane-demo.html new file mode 100644 index 0000000000..6c8bdb92ee --- /dev/null +++ b/tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tf-collapsable-pane-demo.html @@ -0,0 +1,34 @@ + + + + + + + + + + + +

This is content inside the pane.

+
+ + + diff --git a/tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tf-collapsable-pane.html b/tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tf-collapsable-pane.html new file mode 100644 index 0000000000..96b4337668 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tf-collapsable-pane.html @@ -0,0 +1,109 @@ + + + + + + + + + + diff --git a/tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tf-dashboard-layout.html b/tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tf-dashboard-layout.html new file mode 100644 index 0000000000..058b0d5946 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tf-dashboard-layout.html @@ -0,0 +1,67 @@ + + + + + + + + + + + diff --git a/tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tf-dashboard.html b/tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tf-dashboard.html new file mode 100644 index 0000000000..039da39ecb --- /dev/null +++ b/tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tf-dashboard.html @@ -0,0 +1,24 @@ + + + + + + + + + diff --git a/tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tf-downloader.html b/tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tf-downloader.html new file mode 100644 index 0000000000..dd87a47222 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tf-downloader.html @@ -0,0 +1,99 @@ + + + + + + + + + + + diff --git a/tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tf-multi-checkbox-demo.html b/tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tf-multi-checkbox-demo.html new file mode 100644 index 0000000000..2877a2364f --- /dev/null +++ b/tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tf-multi-checkbox-demo.html @@ -0,0 +1,177 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tf-multi-checkbox.html b/tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tf-multi-checkbox.html new file mode 100644 index 0000000000..b45657f7d8 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tf-multi-checkbox.html @@ -0,0 +1,157 @@ + + + + + + + + + + + + + + + + + diff --git a/tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tf-multi-checkbox.ts b/tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tf-multi-checkbox.ts new file mode 100644 index 0000000000..51206dd18c --- /dev/null +++ b/tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tf-multi-checkbox.ts @@ -0,0 +1,206 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the 'License'); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an 'AS IS' BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +import * as _ from 'lodash' + +import {getObjectInitializer, getObjectObserver, getStringInitializer, getStringObserver} from '../tf_storage_d3v4/storage' + +Polymer({ + is: 'tf-multi-checkbox', + properties: { + names: { + type: Array, + value: function() { + return []; + }, + }, // All the runs in consideration + regexInput: { + type: String, + value: getStringInitializer('regexInput', ''), + observer: '_regexInputObserver', + }, // Regex for filtering the runs + regex: {type: Object, computed: '_makeRegex(regexInput)'}, + namesMatchingRegex: { + type: Array, + computed: 'computeNamesMatchingRegex(names.*, regex)' + }, // Runs that match the regex + runSelectionState: { + // if a run is explicitly enabled, True, if explicitly disabled, False. + // if undefined, default value (enable for first k runs, disable after). + type: Object, + value: getObjectInitializer('runSelectionState', {}), + observer: '_storeRunToIsCheckedMapping', + }, + // (Allows state to persist across regex filtering) + outSelected: { + type: Array, + notify: true, + computed: 'computeOutSelected(namesMatchingRegex.*, runSelectionState.*)' + }, + colorScale: { + type: Object, + observer: 'synchronizeColors', + }, // map from run name to css class + maxRunsToEnableByDefault: { + // When TB first loads, if it has k or fewer runs, they are all enabled + // by default. If there are more, then they are all disabled. + type: Number, + value: 40, + }, + _debouncedRegexChange: { + type: Object, + // Updating the regex can be slow, because it involves updating styles + // on a large number of Polymer paper-checkboxes. We don't want to do + // this while the user is typing, as it may make a bad, laggy UI. + // So we debounce the updates that come from user typing. + value: function() { + const _this = this; + var debounced = _.debounce(function(r) { + _this.regexInput = r; + }, 150, {leading: false}); + return function() { + var r = this.$$('#runs-regex').value; + if (r == '') { + // If the user cleared the field, they may be done typing, so + // update more quickly. + this.async(function() { + _this.regexInput = r; + }, 30); + } else { + debounced(r); + }; + }; + }, + }, + }, + listeners: { + 'dom-change': 'synchronizeColors', + }, + observers: [ + '_setIsolatorIcon(runSelectionState, names)', + ], + _storeRunToIsCheckedMapping: getObjectObserver('runSelectionState', {}), + _makeRegex: function(regex) { + try { + return new RegExp(regex) + } catch (e) { + return null; + } + }, + _setIsolatorIcon: function() { + var runMap = this.runSelectionState; + var numChecked = _.filter(_.values(runMap)).length; + var buttons = + Array.prototype.slice.call(this.querySelectorAll('.isolator')); + + buttons.forEach(function(b) { + if (numChecked === 1 && runMap[b.name]) { + b.icon = 'radio-button-checked'; + } else { + b.icon = 'radio-button-unchecked'; + } + }); + }, + computeNamesMatchingRegex: function(__, ___) { + var regex = this.regex; + return this.names.filter(function(n) { + return regex == null || regex.test(n); + }); + }, + computeOutSelected: function(__, ___) { + var runSelectionState = this.runSelectionState; + var num = this.maxRunsToEnableByDefault; + var allEnabled = this.namesMatchingRegex.length <= num; + return this.namesMatchingRegex.filter(function(n, i) { + return runSelectionState[n] == null ? allEnabled : runSelectionState[n]; + }); + }, + synchronizeColors: function(e) { + if (!this.colorScale) return; + + this._setIsolatorIcon(); + + var checkboxes = + Array.prototype.slice.call(this.querySelectorAll('paper-checkbox')); + var scale = this.colorScale; + checkboxes.forEach(function(p) { + var color = scale.scale(p.name); + p.customStyle['--paper-checkbox-checked-color'] = color; + p.customStyle['--paper-checkbox-checked-ink-color'] = color; + p.customStyle['--paper-checkbox-unchecked-color'] = color; + p.customStyle['--paper-checkbox-unchecked-ink-color'] = color; + }); + var buttons = + Array.prototype.slice.call(this.querySelectorAll('.isolator')); + buttons.forEach(function(p) { + var color = scale.scale(p.name); + p.style['color'] = color; + }); + // The updateStyles call fails silently if the browser doesn't have focus, + // e.g. if TensorBoard was opened into a new tab that isn't visible. + // So we wait for requestAnimationFrame. + var _this = this; + window.requestAnimationFrame(function() { + _this.updateStyles(); + }); + }, + _isolateRun: function(e) { + // If user clicks on the label for one run, enable it and disable all other + // runs. + + var name = (Polymer.dom(e) as any).localTarget.name; + var selectionState = {}; + this.names.forEach(function(n) { + selectionState[n] = n == name; + }); + this.runSelectionState = selectionState; + }, + _checkboxChange: function(e) { + var target = (Polymer.dom(e) as any).localTarget; + this.runSelectionState[target.name] = target.checked; + // n.b. notifyPath won't work because run names may have periods. + this.runSelectionState = _.clone(this.runSelectionState); + }, + _isChecked: function(item, outSelectedChange) { + return this.outSelected.indexOf(item) != -1; + }, + _regexInputObserver: getStringObserver('regexInput', ''), + toggleAll: function() { + var _this = this; + var anyToggledOn = this.namesMatchingRegex.some(function(n) { + return _this.runSelectionState[n] + }); + + + var runSelectionStateIsDefault = + Object.keys(this.runSelectionState).length == 0; + + var defaultOff = + this.namesMatchingRegex.length > this.maxRunsToEnableByDefault; + // We have runs toggled either if some were explicitly toggled on, or if + // we are in the default state, and there are few enough that we default + // to toggling on. + anyToggledOn = anyToggledOn || runSelectionStateIsDefault && !defaultOff; + + // If any are toggled on, we turn everything off. Or, if none are toggled + // on, we turn everything on. + + var newRunsDisabled = {}; + this.names.forEach(function(n) { + newRunsDisabled[n] = !anyToggledOn; + }); + this.runSelectionState = newRunsDisabled; + }, +}); diff --git a/tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tf-no-data-warning.html b/tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tf-no-data-warning.html new file mode 100644 index 0000000000..d22d341590 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tf-no-data-warning.html @@ -0,0 +1,129 @@ + + + + + + + + + + diff --git a/tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tf-option-selector.html b/tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tf-option-selector.html new file mode 100644 index 0000000000..d11d43d78a --- /dev/null +++ b/tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tf-option-selector.html @@ -0,0 +1,94 @@ + + + + + + + + + + diff --git a/tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tf-panes-helper.html b/tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tf-panes-helper.html new file mode 100644 index 0000000000..1a850c8999 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tf-panes-helper.html @@ -0,0 +1,352 @@ + + + + + + + + + + + + diff --git a/tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tf-regex-group-demo.html b/tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tf-regex-group-demo.html new file mode 100644 index 0000000000..bf4cc03f34 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tf-regex-group-demo.html @@ -0,0 +1,46 @@ + + + + + + + + + + + + + diff --git a/tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tf-regex-group.html b/tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tf-regex-group.html new file mode 100644 index 0000000000..60ec40f29d --- /dev/null +++ b/tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tf-regex-group.html @@ -0,0 +1,97 @@ + + + + + + + + + + + + diff --git a/tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tf-regex-group.ts b/tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tf-regex-group.ts new file mode 100644 index 0000000000..2c11be0e0b --- /dev/null +++ b/tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tf-regex-group.ts @@ -0,0 +1,84 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +import {getObjectInitializer, getObjectObserver} from '../tf_storage_d3v4/storage' + +Polymer({ + is: 'tf-regex-group', + properties: { + rawRegexes: { + type: Array, + value: getObjectInitializer('rawRegexes', [{regex: '', valid: true}]), + }, + regexes: + {type: Array, computed: 'usableRegexes(rawRegexes.*)', notify: true}, + }, + observers: [ + 'addNewRegexIfNeeded(rawRegexes.*)', + 'checkValidity(rawRegexes.*)', + '_uriStoreRegexes(rawRegexes.*)', + ], + _uriStoreRegexes: getObjectObserver('rawRegexes', [{regex: '', valid: true}]), + checkValidity: function(x) { + var match = x.path.match(/rawRegexes\.(\d+)\.regex/); + if (match) { + var idx = match[1]; + this.set('rawRegexes.' + idx + '.valid', this.isValid(x.value)); + } + }, + isValid: function(s) { + try { + new RegExp(s); + return true; + } catch (e) { + return false; + } + }, + usableRegexes: function(regexes) { + var isValid = this.isValid; + return regexes.base + .filter(function(r) { + // Checking validity here (rather than using the data property) + // is necessary because otherwise we might send invalid regexes due + // to the fact that this function can call before the observer does + return r.regex !== '' && isValid(r.regex); + }) + .map(function(r) { + return r.regex; + }); + }, + addNewRegexIfNeeded: function() { + var last = this.rawRegexes[this.rawRegexes.length - 1]; + if (last.regex !== '') { + this.push('rawRegexes', {regex: '', valid: true}); + } + }, + deleteRegex: function(e) { + if (this.rawRegexes.length > 1) { + this.splice('rawRegexes', e.model.index, 1); + } + }, + moveFocus: function(e) { + if (e.keyCode === 13) { + var idx = e.model.index; + var inputs = Polymer.dom(this.root).querySelectorAll('.regex-input'); + if (idx < this.rawRegexes.length - 1) { + (inputs[idx + 1] as any).$.input.focus(); + } else { + (document.activeElement as HTMLElement).blur(); + } + } + } +}); diff --git a/tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tf-run-selector.html b/tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tf-run-selector.html new file mode 100644 index 0000000000..d2b6b1b194 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tf-run-selector.html @@ -0,0 +1,188 @@ + + + + + + + + + + + + + diff --git a/tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tf-sidebar-helper.html b/tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tf-sidebar-helper.html new file mode 100644 index 0000000000..cde812205b --- /dev/null +++ b/tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tf-sidebar-helper.html @@ -0,0 +1,164 @@ + + + + + + + + + + + -- GitLab From 342d315566211a095a06acb1973b94937dadbc0c Mon Sep 17 00:00:00 2001 From: Mark Heffernan Date: Wed, 26 Apr 2017 14:21:43 -0800 Subject: [PATCH 022/697] Update rematerialization to account for buffer aliasing in memory tracking. Previously rematerialization ignored any buffer aliasing. For one, this caused significant over-rematerialization with while loops. This cl uses points-to analysis to properly account for buffer aliasing. This change includes a couple other related changes: (1) enable a rematerialization to be used by more than one instruction. (2) Avoid rematerializing instructions with control dependencies to avoid invalid aliasing which was prevented by earlier copy-insertion. The net result is generally much fewer instructions rematerialized especially with while loops and reduced memory use. Change: 154352719 --- tensorflow/compiler/xla/service/BUILD | 2 + .../xla/service/hlo_rematerialization.cc | 1181 ++++++++++++----- .../xla/service/hlo_rematerialization.h | 18 + .../xla/service/hlo_rematerialization_test.cc | 221 ++- .../compiler/xla/service/liveness_util.cc | 5 +- .../compiler/xla/service/liveness_util.h | 5 +- 6 files changed, 1050 insertions(+), 382 deletions(-) diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index e5a921674f..aed3d72440 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -1254,6 +1254,7 @@ cc_library( ":hlo_cost_analysis", ":hlo_dce", ":hlo_ordering", + ":liveness_util", ":logical_buffer", ":tuple_points_to_analysis", "//tensorflow/compiler/xla:shape_util", @@ -1271,6 +1272,7 @@ cc_test( deps = [ ":cpu_plugin", ":hlo", + ":hlo_matchers", ":hlo_ordering", ":hlo_rematerialization", "//tensorflow/compiler/xla:shape_util", diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc index 101c9076f8..a2c0b52ab2 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc @@ -29,8 +29,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" +#include "tensorflow/compiler/xla/service/liveness_util.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" -#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" @@ -46,63 +46,58 @@ namespace xla { namespace { -// Returns a vector of the operands of 'instruction' with repeated elements -// removed. -std::vector UniqueOperands(const HloInstruction* instruction) { - std::vector unique_operands; - for (HloInstruction* operand : instruction->operands()) { - if (std::find(unique_operands.begin(), unique_operands.end(), operand) == - unique_operands.end()) { - unique_operands.push_back(operand); - } - } - return unique_operands; -} - // Returns true if the given instruction is rematerializable. bool IsRematerializable(const HloInstruction* instruction) { + // Conservatively, don't rematerialize instruction with control + // dependencies. For one, control dependencies are added to prevent + // interference of aliased buffers (say, in while bodies) and + // rematerialization is ignorant of liveness and may break the intended + // ordering. + if (!instruction->control_predecessors().empty() || + !instruction->control_successors().empty()) { + return false; + } + // Don't rematerialize instructions with side effects, those with a cost that // might not be captured by HloCostAnalysis, or instructions which cannot be // cloned safely. switch (instruction->opcode()) { case HloOpcode::kCall: + case HloOpcode::kConstant: case HloOpcode::kCrossReplicaSum: case HloOpcode::kCustomCall: case HloOpcode::kOutfeed: case HloOpcode::kInfeed: + case HloOpcode::kParameter: case HloOpcode::kRecv: case HloOpcode::kSend: case HloOpcode::kTrace: case HloOpcode::kWhile: return false; default: - break; - } - - // Skip tuple shapes because we do not currently account for buffer aliasing - // properly which results in improperly accounting of rematerialization cost - // for these shapes. - if (ShapeUtil::IsTuple(instruction->shape())) { - return false; - } - for (auto* operand : instruction->operands()) { - if (ShapeUtil::IsTuple(operand->shape())) { - return false; - } + return true; } - - return true; } -// Class which maintains an ordered list of instructions with fast insertion and -// removal of arbitrary elements. +// Class which maintains an ordered list of instructions with fast insertion +// before arbitrary elements. class InstructionList { public: explicit InstructionList(const std::vector order) { + int64 position = 0; for (const HloInstruction* inst : order) { instructions_.push_back(const_cast(inst)); instruction_iterators_.insert({const_cast(inst), std::next(instructions_.end(), -1)}); + // Initially position numbers are uniquely assigned in order. Later as + // instructions are added with InsertBefore* methods, some instructions + // may have duplicate position numbers, but the values will be guaranteed + // to be monotonically increasing through the list, and so is still useful + // for quickly(-ish) determining the order of arbitrary instructions in + // the list. + position_number_[inst] = position; + first_at_position_[position] = inst; + position++; } } @@ -111,22 +106,63 @@ class InstructionList { return instructions_; } - // Insert instruction 'to_insert' before instruction 'before' in the list. - Status InsertBefore(HloInstruction* to_insert, HloInstruction* before) { + // Insert instruction 'to_insert' immediately before instruction 'before' in + // the list. + void InsertBefore(HloInstruction* to_insert, HloInstruction* before) { + VLOG(3) << "InsertBefore: " << to_insert->name() << " before " + << before->name(); auto it = instruction_iterators_.find(before); - TF_RET_CHECK(it != instruction_iterators_.end()); + CHECK(it != instruction_iterators_.end()); instruction_iterators_.insert( {to_insert, instructions_.insert(it->second, to_insert)}); - return Status::OK(); + // Assign the same position number to the newly added instruction as + // 'before'. This guarantees monotonicity of the position numbers, but not + // uniqueness. + int64 pos = position_number_.at(before); + position_number_[to_insert] = pos; + if (first_at_position_.at(pos) == before) { + first_at_position_[pos] = to_insert; + } } - // Removes instruction from the list. - Status Remove(HloInstruction* instruction) { - auto it = instruction_iterators_.find(instruction); - TF_RET_CHECK(it != instruction_iterators_.end()); - instructions_.erase(it->second); - instruction_iterators_.erase(it); - return Status::OK(); + // Insert instruction 'to_insert' immediately before the earliest instruction + // in 'before_instructions'. + void InsertBeforeInstructions( + HloInstruction* to_insert, + tensorflow::gtl::ArraySlice before_instructions) { + VLOG(3) << "InsertBeforeInstructions: " << to_insert->name() << " before {" + << tensorflow::str_util::Join( + before_instructions, ", ", + [](string* out, HloInstruction* inst) { + tensorflow::strings::StrAppend(out, inst->name()); + }) + << "}"; + + // Find the minimal position number of any instruction in + // 'before_instructions'. + CHECK(!before_instructions.empty()); + int64 min_position_number = std::numeric_limits::max(); + for (const HloInstruction* instruction : before_instructions) { + min_position_number = + std::min(min_position_number, position_number_.at(instruction)); + } + + // Because more than one instruction in 'before_instructions' may have a + // position number of 'min_position_number', find the first such instruction + // with position number 'min_position_number'. + for (auto it = instruction_iterators_.at( + first_at_position_.at(min_position_number)); + it != instructions_.end() && + position_number_.at(*it) == min_position_number; + ++it) { + if (std::find(before_instructions.begin(), before_instructions.end(), + *it) != before_instructions.end()) { + return InsertBefore(to_insert, *it); + } + } + LOG(FATAL) << "Expected to find instruction in before_instructions with " + "position number " + << min_position_number; } private: @@ -137,283 +173,626 @@ class InstructionList { tensorflow::gtl::FlatMap::iterator> instruction_iterators_; + + // A number assigned to each instruction which increases monotonically through + // 'instructions_'. Used to facilitate fast insertion of an instruction before + // the earliest instruction in a set of instructions + // (InsertBeforeInstructions) by enabling fast-ish ordering queries between + // instructions. If position_number_[a] < position_number_[b] then 'a' comes + // before 'b' in the list. If the position numbers are the same then nothing + // can be said about their order without examining the list. + // + // On object construction this value is precisely the instruction's ordinal + // position in the list. Instructions inserted via InsertBefore receive + // duplicate values. However, monotonicity is preserved. + tensorflow::gtl::FlatMap position_number_; + + // The first instruction in the list assigned a particular position number. + tensorflow::gtl::FlatMap first_at_position_; }; +// Return the HloInstructions which use the given LogicalBuffer. Sets +// has_indirect_users to whether any of the uses is indirect. A use is indirect +// if the instruction defining logical_buffer is not an operand of the use. This +// can happen via buffer aliasing (eg, tuples). +std::vector GetUsers( + const LogicalBuffer* logical_buffer, + const TuplePointsToAnalysis& points_to_analysis, bool* has_indirect_users) { + std::vector users; + // To identify uses iterate through all HloInstruction users of the + // BufferAliases of the logical buffer. + *has_indirect_users = false; + for (const BufferAlias& buffer_alias : + points_to_analysis.GetBufferAliases(*logical_buffer)) { + for (const HloInstruction* user : buffer_alias.instruction()->users()) { + if (DoesNotUseOperandBuffer(buffer_alias.instruction(), + buffer_alias.index(), user, + points_to_analysis)) { + // The alias may be an operand of 'user', but the LogicalBuffer cannot + // possibly be used by the instruction so ignore 'user'. This is the + // case, for example, for the tuple element buffers in a GetTupleElement + // instruction (the GTE instruction only uses the pointer vector). + continue; + } + if (buffer_alias.instruction() != logical_buffer->instruction()) { + *has_indirect_users = true; + } + users.push_back(user); + } + } + return users; +} + // Class for tracking memory usage of a computation as the instructions are -// placed sequentially. Memory usage is the sum of live values at the current -// point in the instruction sequence. +// placed sequentially. Memory usage is the sum of the sizes of live values +// (LogicalBuffers) at the current point in the instruction sequence. class MemoryUsageTracker { public: MemoryUsageTracker( const HloComputation* computation, - const HloRematerialization::ShapeSizeFunction& size_function) - : computation_(computation), size_function_(size_function) { - for (const std::unique_ptr& instruction : - computation->instructions()) { - // Initially only live-in values occupy memory. - if (IsLiveIn(instruction.get())) { - memory_usage_ += TotalSizeBytes(instruction->shape()); - } - } + const HloRematerialization::ShapeSizeFunction& size_function, + const TuplePointsToAnalysis& points_to_analysis, + const InstructionList& instruction_list); + + // Starts the placement of the given instruction. This adds the sizes of the + // LogicalBuffers defined by the instruction to the current memory + // usage. Placement is broken into two steps (BeginInstruction and + // EndInstruction) to accurately model memory usage. At BeginInstruction the + // memory for the output value(s) of the current instruction is allocated. At + // EndInstruction memory for dead operand(s) is freed. + Status BeginInstruction(const HloInstruction* instruction); + + // Finishes the placement of the current instruction. This frees any dead + // operands or dead result of the instruction. This must be called after + // each call to BeginInstruction. + Status EndInstruction(); + + // Returns the number of bytes that the current memory usage will be reduced + // if the given instruction is rematerialized. + int64 MemoryReducedIfRematerialized(const HloInstruction* instruction) const; + + // Adjusts memory usage to account for the rematerialization of + // original_instruction for all remaining unplaced uses. The rematerialization + // is remat_instruction. This method should be called after the HLO graph has + // been transformed (rematerialization instruction created and connected to + // uses). + Status AddRematerializedInstruction(HloInstruction* original_instruction, + HloInstruction* remat_instruction); + + // Returns whether the given instruction has been placed (BeginInstruction + // has been called with 'instruction' as the argument). + bool IsPlaced(const HloInstruction* instruction) const { + return ContainsKey(placed_instructions_, instruction); + } + + // Returns the current memory usage. This is the sum of sizes of all live + // values. + int64 memory_usage() const { return memory_usage_; } + + // Returns the current instruction being placed. + const HloInstruction* in_progress_instruction() const { + return in_progress_instruction_; } - // Starts the placement of the given instruction. This adds the output size of - // the instruction to the current memory usage. Placement is broken into two - // steps (BeginInstruction and EndInstruction) to accurately model memory - // usage. At BeginInstruction the memory for the output value of the current - // instruction is allocated. At EndInstruction memory for dead operands is - // freed. - Status BeginInstruction(const HloInstruction* instruction) { - VLOG(3) << "BeginInstruction " << instruction->name(); - TF_RET_CHECK(in_progress_instruction_ == nullptr); - in_progress_instruction_ = instruction; + // Check invariants of the data structure. This is expensive to call. + bool Check() const; - // Add instruction to remaining_uses_. - TF_RET_CHECK(!ContainsKey(remaining_uses_, instruction)); - std::vector& instruction_uses = - remaining_uses_[instruction]; - instruction_uses.insert(instruction_uses.begin(), - instruction->users().begin(), - instruction->users().end()); + string ToString() const; - if (!IsLiveIn(instruction)) { - // Instruction was not previously live so add output size to memory usage. - memory_usage_ += TotalSizeBytes(instruction->shape()); + private: + // Type holding a unique identifier for each Buffer object. + using BufferId = int64; + + // A Buffer represents a single LogicalBuffer in the computation including + // various metadata useful for tracking liveness of the value. A LogicalBuffer + // is not used directly because the HLO graph is transformed and + // TuplePointsToAnalysis which owns all LogicalBuffers cannot be updated after + // HLO graph transformations. + struct Buffer { + // The unique id of this Buffer. This value is equal to the buffer's index + // in the vector buffers_. + const BufferId id; + + // The instruction which defines this buffer. + const HloInstruction* defining_instruction; + + // The materialized size of the buffer in bytes. + const int64 size; + + // Whether this buffer is live-out of the computation. + bool live_out; + + // Whether this buffer has indirect uses. Ie, an instruction which is not a + // user of defining_instruction uses this buffer. This can occur due to + // buffer aliasing (eg, tuples). + bool has_indirect_uses; + + // The instructions which use this buffer. + std::vector users; + + // The number of users (HloInstructions) of this buffer which have not yet + // been placed in the sequence. + int64 unfinished_user_count; + + string ToString() const { + return tensorflow::strings::StrCat("Buffer ", id, " (defined by ", + defining_instruction->name(), + ", size ", size, " bytes)"); } + }; + + // Creates a Buffer representing the given logical buffer. The buffer is added + // to buffers_ and a reference is returned. + Buffer& CreateBufferFromLogicalBuffer( + const LogicalBuffer* logical_buffer, + const TuplePointsToAnalysis& points_to_analysis, + const HloRematerialization::ShapeSizeFunction& size_function, + bool live_out) { + bool has_indirect_uses = false; + std::vector users = + GetUsers(logical_buffer, points_to_analysis, &has_indirect_uses); + return NewBuffer(logical_buffer->instruction(), + size_function(logical_buffer->shape()), std::move(users), + live_out, has_indirect_uses); + } - VLOG(3) << " memory usage = " << memory_usage_; - VLOG(10) << ToString(); - return Status::OK(); + // Create a new buffer representing a rematerialization of given buffer for + // the given uses. + Buffer& RematerializeBuffer( + const Buffer& original_buffer, const HloInstruction* remat_instruction, + std::vector&& rematerialized_uses) { + CHECK(IsPlaced(original_buffer.defining_instruction)); + CHECK(!original_buffer.has_indirect_uses); + CHECK(!original_buffer.live_out); + for (const HloInstruction* use : rematerialized_uses) { + CHECK(!IsPlaced(use)); + } + return NewBuffer(remat_instruction, original_buffer.size, + std::move(rematerialized_uses), /*live_out=*/false, + /*has_indirect_uses=*/false); } - // Finishes the placement of the current instruction. This frees any dead - // operands or dead result of the instruction. This must be called after each - // call to BeginInstruction. - Status EndInstruction() { - TF_RET_CHECK(in_progress_instruction_ != nullptr); - VLOG(3) << "EndInstruction " << in_progress_instruction_->name(); - - for (HloInstruction* operand : UniqueOperands(in_progress_instruction_)) { - TF_RET_CHECK(ContainsKey(remaining_uses_, operand)); - std::vector& uses = remaining_uses_.at(operand); - auto it = std::find(uses.begin(), uses.end(), in_progress_instruction_); - TF_RET_CHECK(it != uses.end()); - uses.erase(it); - - if (uses.empty()) { - // Operand is dead. - int64 operand_size = TotalSizeBytes(operand->shape()); - if (!IsLiveOut(operand)) { - VLOG(4) << operand->name() << " (" - << HumanReadableNumBytes(operand_size) << ") is dead"; - memory_usage_ -= operand_size; - TF_RET_CHECK(memory_usage_ >= 0); - } - } + // Return number of bytes allocated for the buffer with the given id. Buffers + // allocated by the calling computation (eg, parameter and output buffers) are + // considered to have zero bytes because the memory is accounted for in a + // different computation. + int64 AllocatedSize(BufferId buffer_id) const { + const Buffer& buffer = buffers_.at(buffer_id); + HloOpcode def_opcode = buffer.defining_instruction->opcode(); + if (buffer.live_out || def_opcode == HloOpcode::kParameter) { + return 0; + } else { + return buffer.size; } + } - // Value is dead if the instruction has no uses and is not live out. - if (in_progress_instruction_->users().empty() && - !IsLiveOut(in_progress_instruction_)) { - memory_usage_ -= TotalSizeBytes(in_progress_instruction_->shape()); - TF_RET_CHECK(memory_usage_ >= 0); + // Returns true if BeginInstruction and EndInstruction has been called for the + // given instruction. + bool IsFinished(const HloInstruction* instruction) const { + return IsPlaced(instruction) && instruction != in_progress_instruction_; + } + + // Returns whether the given buffer is being used by the in-progress + // instruction. + bool IsInUse(BufferId buffer_id) const { + if (in_progress_instruction_ == nullptr) { + return false; } + const std::vector& in_progress_uses = + buffers_used_by_instruction_.at(in_progress_instruction_); + return std::find(in_progress_uses.begin(), in_progress_uses.end(), + buffer_id) != in_progress_uses.end(); + } - in_progress_instruction_ = nullptr; + // Returns whether the given instruction is live at the current program + // point. + bool IsCurrentlyLive(BufferId buffer_id) const { + const Buffer& buffer = buffers_[buffer_id]; + return (IsPlaced(buffer.defining_instruction) && + buffer.unfinished_user_count > 0); + } - VLOG(3) << " memory usage = " << memory_usage_; - VLOG(10) << ToString(); - return Status::OK(); + // Create a new buffer, add it to buffers_, and return a reference. + Buffer& NewBuffer(const HloInstruction* defining_instruction, int64 size, + std::vector&& users, bool live_out, + bool has_indirect_uses) { + int buffer_id = buffers_.size(); + buffers_.push_back(Buffer{buffer_id, defining_instruction, size, live_out, + has_indirect_uses, users, + static_cast(users.size())}); + return buffers_.back(); } - // Adjusts memory usage to account for the rematerialization of - // original_instruction for the given use. The rematerialization is - // remat_instruction. This method should be called after the HLO graph has - // been transformed (rematerialization instruction created and connected to - // its use). - Status RematerializeInstructionForUse(HloInstruction* original_instruction, - HloInstruction* remat_instruction, - HloInstruction* use) { - VLOG(3) << "RematerializeInstructionForUse: original_instruction = " - << original_instruction->name() - << ", remat_instruction = " << remat_instruction->name() - << ", use = " << use->name(); - - TF_RET_CHECK(in_progress_instruction_ != nullptr); - TF_RET_CHECK(IsPlaced(original_instruction)); - TF_RET_CHECK(!IsPlaced(remat_instruction)); - TF_RET_CHECK(!IsPlaced(use)); - TF_RET_CHECK(IsCurrentlyLive(original_instruction)); - - // Remove 'use' from remaining uses of original_instruction. - auto it = std::find(remaining_uses_[original_instruction].begin(), - remaining_uses_[original_instruction].end(), use); - TF_RET_CHECK(it != remaining_uses_[original_instruction].end()); - remaining_uses_[original_instruction].erase(it); - - // If original_instruction is no longer live ('use' was its last use) then - // deduct original_instruction's memory usage. - if (!IsCurrentlyLive(original_instruction)) { - memory_usage_ -= TotalSizeBytes(original_instruction->shape()); - TF_RET_CHECK(memory_usage_ >= 0); - } - - // Add the new remat_instruction to the remaining uses of its operands. - for (auto* operand : UniqueOperands(remat_instruction)) { - // Rematerialization may extend the lifetime of the operand so account for - // this in memory_usage_. - TF_RET_CHECK(IsPlaced(operand)); - if (!IsCurrentlyLive(operand)) { - memory_usage_ += TotalSizeBytes(operand->shape()); + const HloComputation* computation_; + + // Instruction list containing the ordering of instructions in + // computation_. This is the order in which instructions are placed + // (BeginInstruction/EndInstruction calls). + const InstructionList& instruction_list_; + + // Memory usage at the currently placed instruction. + int64 memory_usage_ = 0; + + // The instruction currently being placed. This value is non-null only + // between the calling of BeginInstruction and EndInstruction. + const HloInstruction* in_progress_instruction_ = nullptr; + + // The buffers defined by each instruction. + std::unordered_map> + buffers_defined_by_instruction_; + + // The buffers used by each instruction. + std::unordered_map> + buffers_used_by_instruction_; + + // The set of instructions which have been placed. That is, BeginInstruction + // has been called with the instruction as an argument. + tensorflow::gtl::FlatSet placed_instructions_; + + // All buffers in the computation. + std::vector buffers_; +}; + +MemoryUsageTracker::MemoryUsageTracker( + const HloComputation* computation, + const HloRematerialization::ShapeSizeFunction& size_function, + const TuplePointsToAnalysis& points_to_analysis, + const InstructionList& instruction_list) + : computation_(computation), instruction_list_(instruction_list) { + // Iterate through all LogicalBuffers in the computation and gather the + // instructions which define them in buffers_defined_by_instruction_ and the + // instructions which use them in buffers_used_by_instruction_. + for (auto& instruction : computation_->instructions()) { + // Initialize empty vectors for defs and uses of each instruction. + buffers_used_by_instruction_[instruction.get()]; + buffers_defined_by_instruction_[instruction.get()]; + } + + tensorflow::gtl::FlatSet live_out_set = + points_to_analysis.GetPointsToSet(computation_->root_instruction()) + .CreateFlattenedSet(); + tensorflow::gtl::FlatMap + logical_buffer_to_buffer_id; + + for (const HloInstruction* instruction : instruction_list_.instructions()) { + for (const LogicalBuffer* logical_buffer : + points_to_analysis.GetBuffersDefinedByInstruction(instruction)) { + Buffer* buffer; + if (instruction->opcode() == HloOpcode::kWhile) { + // The while instruction defines no new buffers. Instead it reuses the + // buffers of its operand. Find the Buffer of its operand at the + // proper ShapeIndex. + const PointsToSet& operand_points_to = + points_to_analysis.GetPointsToSet(instruction->operand(0)); + CHECK_EQ(operand_points_to.element(logical_buffer->index()).size(), 1); + const LogicalBuffer* source_logical_buffer = + operand_points_to.element(logical_buffer->index())[0]; + buffer = + &buffers_.at(logical_buffer_to_buffer_id.at(source_logical_buffer)); + + // Mark buffer as has indirect use and live out. + buffer->has_indirect_uses = true; + buffer->live_out = + buffer->live_out || ContainsKey(live_out_set, logical_buffer); + + // Add users of while to Buffer users. + bool unused; + for (const HloInstruction* user : + GetUsers(logical_buffer, points_to_analysis, &unused)) { + if (std::find(buffer->users.begin(), buffer->users.end(), user) == + buffer->users.end()) { + buffer->users.push_back(user); + buffer->unfinished_user_count++; + buffers_used_by_instruction_.at(user).push_back(buffer->id); + } + } + } else { + buffer = &CreateBufferFromLogicalBuffer( + logical_buffer, points_to_analysis, size_function, + ContainsKey(live_out_set, logical_buffer)); + buffers_defined_by_instruction_.at(instruction).push_back(buffer->id); + for (const HloInstruction* user : buffer->users) { + buffers_used_by_instruction_.at(user).push_back(buffer->id); + } } - remaining_uses_.at(operand).push_back(remat_instruction); + + logical_buffer_to_buffer_id[logical_buffer] = buffer->id; } + } + XLA_VLOG_LINES(10, ToString()); + DCHECK(Check()); +} + +Status MemoryUsageTracker::BeginInstruction(const HloInstruction* instruction) { + VLOG(3) << "BeginInstruction " << instruction->name(); + TF_RET_CHECK(in_progress_instruction_ == nullptr); + in_progress_instruction_ = instruction; + + placed_instructions_.insert(in_progress_instruction_); - VLOG(3) << " memory usage = " << memory_usage_; - VLOG(10) << ToString(); - return Status::OK(); + // All buffers defined by this instruction need memory. + for (BufferId buffer_id : buffers_defined_by_instruction_.at(instruction)) { + VLOG(3) << " Buffer " << buffers_.at(buffer_id).ToString() + << " is now live."; + memory_usage_ += AllocatedSize(buffer_id); } - // Returns the number of bytes that the current memory usage will be reduced - // if the given instruction is rematerialized. - int64 MemoryReducedIfRematerialized(const HloInstruction* instruction) const { - // To reduce memory consumption 'instruction' must be currently live and - // rematerialization must make 'instruction' not live. - if (IsLiveIn(instruction) || IsLiveOut(instruction) || - !IsCurrentlyLive(instruction)) { - return 0; - } + // TODO(b/37686934): Elementwise instructions can share the buffer of a (dead) + // operand. Account for this potential reuse here. - // If the in-progress instruction is a user of 'instruction' (or - // 'instruction' itself) then rematerializing 'instruction' cannot reduce - // memory usage because the value is required to be live at this program - // point. - if (in_progress_instruction_ == instruction || - in_progress_instruction_->IsUserOf(instruction)) { - return 0; - } + VLOG(3) << " memory usage = " << memory_usage_; + VLOG(10) << ToString(); - // Compute the amount of memory reduced (if any) by rematerializing - // 'instruction'. 'instruction' will no longer be live at this program - // point, so initially set memory_reduced to the size of its output value. - int64 memory_reduced = TotalSizeBytes(instruction->shape()); + DCHECK(Check()); + return Status::OK(); +} - // Account for any operands whose live range must be extended across this - // program point. - for (const HloInstruction* operand : UniqueOperands(instruction)) { - if (!IsCurrentlyLive(operand)) { - // This operand of candidate is not live at this program - // point. Rematerializing 'instruction' will extend the operand's live - // range across this program point. - memory_reduced -= TotalSizeBytes(operand->shape()); - } +Status MemoryUsageTracker::EndInstruction() { + TF_RET_CHECK(in_progress_instruction_ != nullptr); + VLOG(3) << "EndInstruction " << in_progress_instruction_->name(); + + for (BufferId buffer_id : + buffers_used_by_instruction_.at(in_progress_instruction_)) { + Buffer& buffer = buffers_.at(buffer_id); + buffer.unfinished_user_count--; + CHECK_GE(buffer.unfinished_user_count, 0) + << buffer.ToString() << " has negative unfinished use count."; + if (buffer.unfinished_user_count == 0) { + // Buffer is now dead. + VLOG(3) << " " << buffer.ToString() << " is now dead."; + memory_usage_ -= AllocatedSize(buffer_id); + CHECK_GE(memory_usage_, 0); } - return memory_reduced; } - // Returns the remaining unplaced uses of the given instruction. - const std::vector& RemainingUses( - const HloInstruction* instruction) const { - return remaining_uses_.at(instruction); + // If any buffer defined by this instruction has no uses, then memory can be + // reclaimed immediately. + for (BufferId buffer_id : + buffers_defined_by_instruction_.at(in_progress_instruction_)) { + const Buffer& buffer = buffers_.at(buffer_id); + if (buffer.unfinished_user_count == 0) { + VLOG(3) << " " << buffer.ToString() << " is immediately dead."; + memory_usage_ -= AllocatedSize(buffer_id); + CHECK_GE(memory_usage_, 0); + } } - // Returns whether the given instruction has been placed (BeginInstruction has - // been called with 'instruction' as the argument). - bool IsPlaced(const HloInstruction* instruction) const { - return ContainsKey(remaining_uses_, instruction); + in_progress_instruction_ = nullptr; + + VLOG(3) << " memory usage = " << memory_usage_; + VLOG(10) << ToString(); + + DCHECK(Check()); + + return Status::OK(); +} + +int64 MemoryUsageTracker::MemoryReducedIfRematerialized( + const HloInstruction* instruction) const { + CHECK_NE(in_progress_instruction_, nullptr); + if (!IsPlaced(instruction) || instruction == in_progress_instruction_) { + return 0; } - // Returns whether the given instruction is live at the current program point. - bool IsCurrentlyLive(const HloInstruction* instruction) const { - return (!IsPlaced(instruction) && IsLiveIn(instruction)) || - (IsPlaced(instruction) && - (!RemainingUses(instruction).empty() || IsLiveOut(instruction))); + // TODO(b/37687140): Rematerialization can increase peak memory consumption at + // an earlier point in the program if rematerialization extends the live range + // of the operand of the instruction being rematerialized across the live + // range of the value of instruction being rematerialized. Don't rematerialize + // in this case (ie, return 0 here). + + // Compute the amount of memory reduced (if any) by rematerializing + // 'instruction'. The LogicalBuffers defined by 'instruction' will no longer + // be live at this program point, so initially set memory_reduced to the + // size of its defined values. + int64 memory_reduced = 0; + for (BufferId buffer_id : buffers_defined_by_instruction_.at(instruction)) { + // Avoid rematerializing instructions with indirect uses as it is difficult + // to reason about liveness after rematerializing the instruction. + // TODO(b/37714814): Consider rematerialzing instructions with indirect + // uses. + if (buffers_.at(buffer_id).has_indirect_uses) { + return 0; + } + + if (IsCurrentlyLive(buffer_id) && !IsInUse(buffer_id)) { + memory_reduced += AllocatedSize(buffer_id); + } } - string ToString() const { - string output = tensorflow::strings::StrCat("MemoryUsageTracker for ", - computation_->name(), "\n"); - tensorflow::strings::StrAppend(&output, "memory usage = ", memory_usage(), - "\n"); - tensorflow::strings::StrAppend(&output, "Live values:\n"); - for (const auto& pair : remaining_uses_) { - const HloInstruction* instruction = pair.first; - const std::vector& uses = pair.second; - tensorflow::strings::StrAppend( - &output, " ", instruction->name(), "; remaining uses: ", - tensorflow::str_util::Join(uses, ", ", - [](string* out, HloInstruction* use) { - tensorflow::strings::StrAppend( - out, use->name()); - }), - "\n"); + // Account for any logical buffers whose live range must be extended across + // this program point. + for (BufferId buffer_id : buffers_used_by_instruction_.at(instruction)) { + if (!IsCurrentlyLive(buffer_id)) { + // This logical buffer is used by 'instruction' but is not live at this + // program point. Rematerializing 'instruction' will extend the buffer's + // live range across this program point. + memory_reduced -= AllocatedSize(buffer_id); } - return output; } - // Returns the current memory usage. This is the sum of sizes of all live - // values. - int64 memory_usage() const { return memory_usage_; } + return memory_reduced; +} - // Returns the current instruction being placed. - const HloInstruction* in_progress_instruction() const { - return in_progress_instruction_; - } +Status MemoryUsageTracker::AddRematerializedInstruction( + HloInstruction* original_instruction, HloInstruction* remat_instruction) { + VLOG(3) << "AddRematerializedInstruction: original_instruction = " + << original_instruction->name() + << ", remat_instruction = " << remat_instruction->name(); + + TF_RET_CHECK(in_progress_instruction_ != nullptr); + TF_RET_CHECK(IsPlaced(original_instruction)); + TF_RET_CHECK(!IsPlaced(remat_instruction)); + CHECK(!ContainsKey(buffers_defined_by_instruction_, remat_instruction)); + CHECK(!ContainsKey(buffers_used_by_instruction_, remat_instruction)); + + // Construct the list of buffers used and defined by the rematerialization. + buffers_defined_by_instruction_[remat_instruction]; + buffers_used_by_instruction_[remat_instruction] = + buffers_used_by_instruction_.at(original_instruction); + + // Account for the additional buffer uses created by the new rematerialization + // instruction. Update memory usage if the rematerialization makes a dead + // buffer live again. + for (BufferId buffer_id : + buffers_used_by_instruction_.at(original_instruction)) { + Buffer& buffer = buffers_.at(buffer_id); + if (buffer.unfinished_user_count == 0) { + // Buffer used by this instruction was dead, now is alive. + memory_usage_ += AllocatedSize(buffer.id); + } - private: - // Returns the total size of the shape (including nested elements) in bytes. - int64 TotalSizeBytes(const Shape& shape) const { - int64 total_size = 0; - ShapeUtil::ForEachSubshape( - shape, - [this, &total_size](const Shape& subshape, - const ShapeIndex& /*index*/) { - total_size += size_function_(subshape); - return Status::OK(); - }) - .IgnoreError(); - return total_size; + buffer.unfinished_user_count++; + buffer.users.push_back(remat_instruction); } - // Returns true if the value of given instruction is live into the - // computation. - bool IsLiveIn(const HloInstruction* instruction) const { - return instruction->opcode() == HloOpcode::kConstant || - instruction->opcode() == HloOpcode::kParameter; + // Create a new set of Buffers defined by the new rematerialization + // instruction. Update the internal data structures and memory use to account + // for them. + for (BufferId old_buffer_id : + buffers_defined_by_instruction_.at(original_instruction)) { + Buffer& old_buffer = buffers_.at(old_buffer_id); + + std::vector placed_users; + std::vector unplaced_users; + for (const HloInstruction* user : old_buffer.users) { + if (IsPlaced(user)) { + CHECK(IsFinished(user)); + placed_users.push_back(user); + } else { + unplaced_users.push_back(user); + } + } + old_buffer.users = std::move(placed_users); + old_buffer.unfinished_user_count = 0; + + // Buffer is now dead. + memory_usage_ -= AllocatedSize(old_buffer.id); + + Buffer& new_buffer = RematerializeBuffer(old_buffer, remat_instruction, + std::move(unplaced_users)); + + buffers_defined_by_instruction_.at(remat_instruction) + .push_back(new_buffer.id); + for (const HloInstruction* user : new_buffer.users) { + std::vector& buffers_used = + buffers_used_by_instruction_.at(user); + std::replace(buffers_used.begin(), buffers_used.end(), old_buffer_id, + new_buffer.id); + } } - // Returns true if the value of given instruction is live out of the - // computation. - bool IsLiveOut(const HloInstruction* instruction) const { - return instruction->opcode() == HloOpcode::kConstant || - instruction->opcode() == HloOpcode::kParameter || - instruction == instruction->parent()->root_instruction(); + VLOG(3) << " memory usage = " << memory_usage_; + XLA_VLOG_LINES(10, ToString()); + + DCHECK(Check()); + + return Status::OK(); +} + +string MemoryUsageTracker::ToString() const { + string output = tensorflow::strings::StrCat("MemoryUsageTracker for ", + computation_->name(), "\n"); + tensorflow::strings::StrAppend( + &output, "Memory usage: ", HumanReadableNumBytes(memory_usage()), " (", + memory_usage(), " bytes)"); + for (const HloInstruction* instruction : instruction_list_.instructions()) { + string inprogress = + instruction == in_progress_instruction_ ? " in-progress" : ""; + string placed = IsPlaced(instruction) ? " placed" : ""; + tensorflow::strings::StrAppend(&output, " ", instruction->name(), + inprogress, placed, "\n Defines:\n"); + for (BufferId buffer_id : buffers_defined_by_instruction_.at(instruction)) { + const Buffer& buffer = buffers_[buffer_id]; + string live = IsCurrentlyLive(buffer_id) ? " live" : ""; + tensorflow::strings::StrAppend(&output, " ", buffer.ToString(), live, + ", ", buffer.unfinished_user_count, + " unfinished uses\n"); + } + tensorflow::strings::StrAppend(&output, " Uses:\n"); + for (BufferId buffer_id : buffers_used_by_instruction_.at(instruction)) { + tensorflow::strings::StrAppend(&output, " ", + buffers_[buffer_id].ToString(), "\n"); + } } + return output; +} - const HloComputation* computation_; +bool MemoryUsageTracker::Check() const { + auto elements_are_unique = [](const std::vector& vec) { + return vec.size() == std::set(vec.begin(), vec.end()).size(); + }; + + // Verify buffers_defined_by_instruction_. + for (auto& instruction : computation_->instructions()) { + const std::vector& defined_buffers = + buffers_defined_by_instruction_.at(instruction.get()); + CHECK(elements_are_unique(defined_buffers)) + << "Instruction " << instruction->name() + << " does not have unique defined buffers: " + << tensorflow::str_util::Join( + defined_buffers, ", ", [this](string* out, BufferId buffer_id) { + tensorflow::strings::StrAppend( + out, buffers_.at(buffer_id).ToString()); + }); - // Function which computes the size of the top-level buffer of a shape. - const HloRematerialization::ShapeSizeFunction size_function_; + for (const Buffer& buffer : buffers_) { + if (buffer.defining_instruction == instruction.get()) { + CHECK(std::find(defined_buffers.begin(), defined_buffers.end(), + buffer.id) != defined_buffers.end()) + << "Instruction " << instruction->name() + << " defined buffers is missing: " << buffer.ToString(); + } + } + } - // Memory usage at the currently placed instruction. - int64 memory_usage_ = 0; + // Verify buffers_used_by_instruction_. + for (auto& instruction : computation_->instructions()) { + const std::vector& used_buffers = + buffers_used_by_instruction_.at(instruction.get()); + CHECK(elements_are_unique(used_buffers)) + << "Instruction " << instruction->name() + << " does not have unique used buffers: " + << tensorflow::str_util::Join( + used_buffers, ", ", [this](string* out, BufferId buffer_id) { + tensorflow::strings::StrAppend( + out, buffers_.at(buffer_id).ToString()); + }); + } + for (const Buffer& buffer : buffers_) { + int64 unfinished_uses = 0; + for (const HloInstruction* user : buffer.users) { + const std::vector& used_buffers = + buffers_used_by_instruction_.at(user); + CHECK(std::find(used_buffers.begin(), used_buffers.end(), buffer.id) != + used_buffers.end()) + << "Instruction " << user->name() << " used buffers is missing " + << buffer.ToString(); + if (!IsFinished(user)) { + unfinished_uses++; + } + } + CHECK_EQ(buffer.unfinished_user_count, unfinished_uses) + << "Incorrect unplaced use count for " << buffer.ToString(); + } - // The instruction currently being placed. This value is non-null only between - // the calling of BeginInstruction and EndInstruction. - const HloInstruction* in_progress_instruction_ = nullptr; + // Verify live set size against memory_usage_. + int64 live_size = 0; + for (const Buffer& buffer : buffers_) { + // The while instruction reuses its input buffers as output buffers so + // don't double count its buffers if it is currently executing. + if (IsCurrentlyLive(buffer.id) && + !(buffer.defining_instruction == in_progress_instruction_ && + in_progress_instruction_->opcode() == HloOpcode::kWhile)) { + live_size += AllocatedSize(buffer.id); + } + } + CHECK_EQ(live_size, memory_usage_); - // remaining_uses is a vector of uses of the HLO instruction's value which - // have not yet been visited by in the rematerialization loop. Use to track - // liveness of HLO instructions. - // TODO(b/35212854): Track values using logical buffers rather than HLO - // instructions. Using HLO instructions over-estimates memory usage because - // buffer aliasing is ignored. - tensorflow::gtl::FlatMap> - remaining_uses_; -}; + return true; +} -// Computes and returns the cost of rematerializing the given instruction. Cost -// per rematerialized instruction is defined as: +// Computes and returns the cost of rematerializing the given instruction. +// Cost per rematerialized instruction is defined as: // // (flop_count + transcendental_count + element_count) / memory_reduced // @@ -425,33 +804,36 @@ class MemoryUsageTracker { // instruction. // // This is a rough estimate of the extra execution time per byte saved by -// rematerializing this instruction for its remaining uses. In general, we want -// the most memory saving for the least latency penalty which is captured by -// this heuristic. +// rematerializing this instruction for its remaining uses. In general, we +// want the most memory saving for the least latency penalty which is captured +// by this heuristic. int64 RematerializationCost(const HloInstruction* instruction, const MemoryUsageTracker& memory_tracker, const HloCostAnalysis& cost_analysis, int64 memory_reduced) { - const int64 bytes_accessed = cost_analysis.bytes_accessed(*instruction); - const int64 elements_accessed = - bytes_accessed / - ShapeUtil::ByteSizeOfPrimitiveType(instruction->shape().element_type()); - - // A duplicate of the rematerialized instruction will be created at each - // remaining use. - int64 duplication = memory_tracker.RemainingUses(instruction).size(); - if (duplication == instruction->users().size()) { - // All remaining uses of instruction are after this point so we can remove - // the original instruciton after rematerialization. - duplication -= 1; + // If none of the users of 'instruction' have been placed in the sequence (as + // tracked by memory_tracker), then rematerialization of 'instruction' is a + // zero-cost move of 'instruction' in the sequence. + if (!std::any_of(instruction->users().begin(), instruction->users().end(), + [&memory_tracker](const HloInstruction* inst) { + return memory_tracker.IsPlaced(inst); + })) { + return 0; } + CHECK_GT(memory_reduced, 0); + const int64 bytes_accessed = cost_analysis.bytes_accessed(*instruction); + const int64 elements_accessed = + ShapeUtil::IsTuple(instruction->shape()) + ? bytes_accessed + : bytes_accessed / ShapeUtil::ByteSizeOfPrimitiveType( + instruction->shape().element_type()); // Multiply by 256 to improve precision of cost. Without this factor, // many instructions such as many elementwise instructions would have // zero cost because the bytes reduced can be several times greater than // the element count. - return 256 * duplication * + return 256 * (cost_analysis.flop_count(*instruction) + cost_analysis.transcendental_count(*instruction) + elements_accessed) / @@ -467,7 +849,7 @@ HloInstruction* PickRematerializationCandidate( const MemoryUsageTracker& memory_tracker, const InstructionList& instruction_list, const HloCostAnalysis& cost_analysis, - const tensorflow::gtl::FlatSet& remat_instructions) { + const tensorflow::gtl::FlatSet& blacklist) { HloInstruction* best = nullptr; int64 best_cost = 0; @@ -482,11 +864,11 @@ HloInstruction* PickRematerializationCandidate( } VLOG(5) << "considering rematerialization candidate " << candidate->name(); - if (ContainsKey(remat_instructions, candidate)) { - // Skip instructions which are rematerialization clones to avoid infinite - // loops of rematerializing the same instruction(s) repeatedly. + if (ContainsKey(blacklist, candidate)) { + // Skip instructions on the blacklist to avoid infinite loops of + // rematerializing the same instruction(s) repeatedly. VLOG(5) << "candidate " << candidate->name() - << " not viable: is a rematerialized instruction"; + << " is excluded from rematerialization"; continue; } @@ -525,7 +907,9 @@ HloInstruction* PickRematerializationCandidate( StatusOr HloRematerialization::ComputePeakMemory( const HloComputation* computation, const std::vector& order) const { - MemoryUsageTracker tracker(computation, size_function_); + InstructionList instruction_list(order); + MemoryUsageTracker tracker(computation, size_function_, *points_to_analysis_, + instruction_list); int64 peak_memory = tracker.memory_usage(); for (const HloInstruction* instruction : order) { TF_RETURN_IF_ERROR(tracker.BeginInstruction(instruction)); @@ -564,15 +948,24 @@ StatusOr HloRematerialization::RematerializeComputation( << " with limit " << HumanReadableNumBytes(memory_limit_bytes); VLOG(1) << "peak memory usage is " << HumanReadableNumBytes(computation_peak_memory_.at(computation)); + CHECK(!ContainsKey(rematerialized_computations_, computation)); InstructionList instruction_list(sequence->at(computation)); - MemoryUsageTracker memory_tracker(computation, size_function_); + MemoryUsageTracker memory_tracker(computation, size_function_, + *points_to_analysis_, instruction_list); bool changed = false; - // Set of instruction clones (not the originals) created during - // rematerialization. A record is kept to avoid rematerializing an instruction - // more than once to avoid looping infinitely during rematerialization. - tensorflow::gtl::FlatSet remat_instructions; + // To avoid an infinite loop rematerializing the same set of instructions ad + // infinitum, keep a blacklist of instructions which should not be + // rematerialized. + tensorflow::gtl::FlatSet blacklist; + + // If the rematerialization makes the source instruction dead, then the + // rematerialization is added to 'remat_move_instructions' (the + // rematerialization is essentially a move). If the next rematerialization of + // the instruction is also a move then the rematerialization is added to the + // blacklist. + tensorflow::gtl::FlatSet remat_move_instructions; // The peak memory of the computation at any point in the instruction // sequence. @@ -590,6 +983,7 @@ StatusOr HloRematerialization::RematerializeComputation( // Iterate through all instructions in the sequence. At each instruction // (program point) if memory_usage exceeds the specified limit then // rematerialize HLO instructions until memory_usage is reduced. + int64 instruction_index = 0; for (auto list_it = instruction_list.instructions().begin(); list_it != instruction_list.instructions().end(); ++list_it) { HloInstruction* instruction = *list_it; @@ -599,7 +993,9 @@ StatusOr HloRematerialization::RematerializeComputation( VLOG(2) << "Program point at " << instruction->name() << ", memory usage = " << memory_tracker.memory_usage() - << ", callee usage = " << callee_usage; + << ", callee usage = " << callee_usage << ", [" << instruction_index + << "/" << instruction_list.instructions().size() << "]"; + instruction_index++; while (memory_tracker.memory_usage() + callee_usage > memory_limit_bytes) { VLOG(2) << "Over memory limit at instruction " << instruction->name() @@ -609,7 +1005,7 @@ StatusOr HloRematerialization::RematerializeComputation( << ", limit is " << HumanReadableNumBytes(memory_limit_bytes); HloInstruction* best = PickRematerializationCandidate( - memory_tracker, instruction_list, cost_analysis_, remat_instructions); + memory_tracker, instruction_list, cost_analysis_, blacklist); if (best == nullptr) { VLOG(3) << "Unable to find rematerialization candidate at program " @@ -620,44 +1016,42 @@ StatusOr HloRematerialization::RematerializeComputation( break; } - VLOG(1) << "Rematerializing instruction " << best->name(); + VLOG(1) << "Rematerializing instruction " << best->name() << " (saving " + << memory_tracker.MemoryReducedIfRematerialized(best) << ")"; changed = true; remat_count++; - // Create a rematerialized copy of the candidate at each remaining use. - // Make a copy of remaining uses because RematerializeInstructionForUse - // modifies the remaining uses vector in memory_tracker. - // TODO(b/35213652): It may be profitable to share one rematerialized copy - // amongst more than one use. - std::vector remaining_uses_copy = - memory_tracker.RemainingUses(best); - for (HloInstruction* use : remaining_uses_copy) { - // Create a new rematerialized instruction in the HLO graph. - HloInstruction* remat = - computation->AddInstruction(best->Clone(/*suffix=*/"remat")); - - VLOG(3) << "Replacing use of " << best->name() << " in " << use->name() - << " with rematerialization " << remat->name(); - - TF_RETURN_IF_ERROR(best->ReplaceUseWith(use, remat)); - - // Account for the rematerialization in the memory tracker. - TF_RETURN_IF_ERROR( - memory_tracker.RematerializeInstructionForUse(best, remat, use)); + HloInstruction* remat = + computation->AddInstruction(best->Clone(/*suffix=*/"remat")); - // Insert rematerialized instruction right before its use. - TF_RETURN_IF_ERROR(instruction_list.InsertBefore(remat, use)); - - // Add rematerialized instruction to remat_instructions so the - // rematerialized instruction is not rematerialized again. - remat_instructions.insert(remat); - - net_instructions_added++; + // Replace each remaining use of 'best' with the rematerialization. + std::vector best_users_copy = best->users(); + for (HloInstruction* user : best_users_copy) { + if (!memory_tracker.IsPlaced(user)) { + VLOG(2) << " Replacing use of " << best->name() << " in " + << user->name() << " with " << remat->name(); + TF_RETURN_IF_ERROR(best->ReplaceUseWith(user, remat)); + } } - // Original instruction should no longer be live at this point. All - // of its remaining uses are fed by rematerialized instructions. - TF_RET_CHECK(!memory_tracker.IsCurrentlyLive(best)); + // Account for the rematerialization in the memory tracker. + TF_RETURN_IF_ERROR( + memory_tracker.AddRematerializedInstruction(best, remat)); + + // Insert rematerialized instruction right before the earliest unplaced + // use of the instruction *and* the earliest unplaced last use of any + // operands of remat. Unplaced uses of the remat's operands are included + // because we don't want to extend the live range of remat's operands as + // this could increase memory usage. + std::vector place_before = remat->users(); + for (auto* operand : remat->operands()) { + for (auto* operand_user : operand->users()) { + if (!memory_tracker.IsPlaced(operand_user) && operand_user != remat) { + place_before.push_back(operand_user); + } + } + } + instruction_list.InsertBeforeInstructions(remat, place_before); // If the rematerialized instruction is dead then rematerialization is // essentially a move. Don't delete the instruction now because we don't @@ -665,8 +1059,17 @@ StatusOr HloRematerialization::RematerializeComputation( // transformation because we keep maps with HloInstruction* values as // keys. if (best->users().empty()) { - VLOG(3) << best->name() << " is now dead"; - net_instructions_added--; + VLOG(2) << best->name() << " is now dead"; + if (ContainsKey(remat_move_instructions, best)) { + // Previously, 'best' was a rematerialization which killed the + // instruction it was a copying of. Now 'remat' is a rematerialization + // of 'best' and kills 'best'. Stop rematerializing this instruction + // to avoid an infinite loop. + blacklist.insert(remat); + } + remat_move_instructions.insert(remat); + } else { + net_instructions_added++; } VLOG(3) << "memory_usage after rematerialization = " @@ -687,21 +1090,22 @@ StatusOr HloRematerialization::RematerializeComputation( // Recompute callee usage to account for any rematerialization performed // in the callee computations. - callee_usage = 0; for (HloComputation* called_computation : callsite->called_computations()) { - // Memory limit for the subcomputation is the memory limit less the - // amount of memory used at this point in the computation. - int64 subcomputation_memory_limit_bytes = std::max( - 0, memory_limit_bytes - memory_tracker.memory_usage()); - TF_ASSIGN_OR_RETURN( - bool subcomputation_changed, - RematerializeComputation(called_computation, sequence, - subcomputation_memory_limit_bytes)); - changed |= subcomputation_changed; - - callee_usage += computation_peak_memory_.at(called_computation); + if (!ContainsKey(rematerialized_computations_, called_computation)) { + // Memory limit for the subcomputation is the memory limit less the + // amount of memory used at this point in the computation. + int64 subcomputation_memory_limit_bytes = std::max( + 0, memory_limit_bytes - memory_tracker.memory_usage()); + TF_ASSIGN_OR_RETURN( + bool subcomputation_changed, + RematerializeComputation(called_computation, sequence, + subcomputation_memory_limit_bytes)); + changed |= subcomputation_changed; + } } + TF_ASSIGN_OR_RETURN(callee_usage, + CalledComputationsMemoryUsage(instruction)); } peak_memory = std::max(peak_memory, @@ -711,37 +1115,33 @@ StatusOr HloRematerialization::RematerializeComputation( TF_RETURN_IF_ERROR(memory_tracker.EndInstruction()); } - if (peak_memory > memory_limit_bytes) { - LOG(WARNING) << "Can't reduce memory use of computation " - << computation->name() << " below " - << HumanReadableNumBytes(memory_limit_bytes) - << " by rematerialization (only reduced to " - << HumanReadableNumBytes(peak_memory) << ")"; - } - - // Verify that there are no more remaining uses. + // Verify some invariants on the memory tracker. + CHECK_EQ(memory_tracker.memory_usage(), 0); for (auto& instruction : computation->instructions()) { - auto& remaining_uses = memory_tracker.RemainingUses(instruction.get()); - CHECK(remaining_uses.empty()) - << instruction->name() << " has remaining uses: " - << tensorflow::str_util::Join( - remaining_uses, ", ", [](string* out, HloInstruction* inst) { - tensorflow::strings::StrAppend(out, inst->name()); - }); + CHECK(memory_tracker.IsPlaced(instruction.get())); } - VLOG(1) << "Rematerialized " << remat_count << " instructions; " - << net_instructions_added << " net instructions added"; - VLOG(1) << "peak memory usage now " << HumanReadableNumBytes(peak_memory); + VLOG(1) << "In computation " << computation->name() << " rematerialized " + << remat_count << " instructions; " << net_instructions_added + << " net instructions added"; + VLOG(1) << " peak memory usage now " << HumanReadableNumBytes(peak_memory) + << " (was " + << HumanReadableNumBytes(computation_peak_memory_.at(computation)) + << ")"; // Update peak memory used by computation. - computation_peak_memory_[computation] = peak_memory; + computation_peak_memory_.at(computation) = peak_memory; // Update order to include rematerialized instructions. sequence->at(computation) .assign(instruction_list.instructions().begin(), instruction_list.instructions().end()); + rematerialized_computations_.insert(computation); + + instructions_rematerialized_ += remat_count; + net_instructions_added_ += net_instructions_added; + return changed; } @@ -754,6 +1154,33 @@ StatusOr HloRematerialization::Run( VLOG(1) << "HloRematerialization() with memory limit of " << HumanReadableNumBytes(memory_limit_bytes); + TF_ASSIGN_OR_RETURN(points_to_analysis_, + TuplePointsToAnalysis::Run( + module, /*include_loop_fusion_instructions=*/true)); + + // Adjust memory limit to account for the parameter and output of the entry + // computation. This is necessary because the per-computation accounting in + // MemoryUsageTracker do not include parameters and output as these are + // typically allocated by the caller. With this adjustment the memory limit + // accounts for the size of all HLO instructions (parameters, output + // instructions, etc). + auto total_size = [this](const HloInstruction* instruction) { + int64 total_size = 0; + for (const LogicalBuffer* logical_buffer : + points_to_analysis_->GetBuffersDefinedByInstruction(instruction)) { + total_size += size_function_(logical_buffer->shape()); + } + return total_size; + }; + const HloComputation* entry_computation = module->entry_computation(); + memory_limit_bytes -= total_size(entry_computation->root_instruction()); + for (const HloInstruction* param : + entry_computation->parameter_instructions()) { + memory_limit_bytes -= total_size(param); + } + VLOG(1) << "Adjusted memory limit accounting for parameters and output: " + << HumanReadableNumBytes(memory_limit_bytes); + XLA_VLOG_LINES(3, "Before HloRematerialization:\n" + module->ToString()); // Create initial sequence of HLO instructions. TF_ASSIGN_OR_RETURN(*sequence, @@ -761,7 +1188,6 @@ StatusOr HloRematerialization::Run( *module, [this](const LogicalBuffer& buffer) { return size_function_(buffer.shape()); })); - // Compute peak memory usage of all computations in the module called in a // sequential context. TF_ASSIGN_OR_RETURN(call_graph_, CallGraph::Build(module)); @@ -776,9 +1202,10 @@ StatusOr HloRematerialization::Run( return Status::OK(); })); + const int64 before_peak_memory = + computation_peak_memory_.at(module->entry_computation()); VLOG(1) << "Peak memory usage of module (before): " - << HumanReadableNumBytes( - computation_peak_memory_[module->entry_computation()]); + << HumanReadableNumBytes(before_peak_memory); // Run cost analysis. Operation cost is used in the heuristic for selecting // instructions for rematerialization. @@ -824,19 +1251,37 @@ StatusOr HloRematerialization::Run( computation->instruction_count()); } } - - VLOG(1) << "Peak memory usage of module (after): " - << HumanReadableNumBytes( - computation_peak_memory_[module->entry_computation()]); + VLOG(1) << "Rematerialized " << instructions_rematerialized_ + << " instructions in module " << module->name() << "; " + << net_instructions_added_ << " net instructions added"; + const int64 current_peak_memory = + computation_peak_memory_.at(module->entry_computation()); + VLOG(1) << "Peak memory usage of module now " + << HumanReadableNumBytes(current_peak_memory) << " (" + << current_peak_memory << " bytes), was " + << HumanReadableNumBytes(before_peak_memory) << " (" + << before_peak_memory << " bytes)"; + const int64 reduced_peak_memory = before_peak_memory - current_peak_memory; + VLOG(1) << "Reduced peak memory by " + << HumanReadableNumBytes(reduced_peak_memory) << " (" + << reduced_peak_memory << " bytes)"; XLA_VLOG_LINES(3, "After HloRematerialization:\n" + module->ToString()); + if (current_peak_memory > memory_limit_bytes) { + LOG(WARNING) << "Can't reduce memory use below " + << HumanReadableNumBytes(memory_limit_bytes) + << " by rematerialization (only reduced to " + << HumanReadableNumBytes(current_peak_memory) << ")"; + } + return changed; } /* static */ StatusOr HloRematerialization::RematerializeAndSchedule( - const ShapeSizeFunction& size_function, int64 memory_limit_bytes, - HloModule* hlo_module, SequentialHloOrdering::HloModuleSequence* sequence) { + const HloRematerialization::ShapeSizeFunction& size_function, + int64 memory_limit_bytes, HloModule* hlo_module, + SequentialHloOrdering::HloModuleSequence* sequence) { HloRematerialization remat(size_function); return remat.Run(hlo_module, sequence, memory_limit_bytes); } diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.h b/tensorflow/compiler/xla/service/hlo_rematerialization.h index 86e1998b89..1693f93183 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.h +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.h @@ -21,6 +21,7 @@ #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" namespace xla { @@ -108,6 +109,23 @@ class HloRematerialization { // occurs. tensorflow::gtl::FlatMap computation_peak_memory_; + + std::unique_ptr points_to_analysis_; + + // Set of computations which have had rematerialization + // applied. Rematerialization is only applied once per computation. + tensorflow::gtl::FlatSet rematerialized_computations_; + + // Count of the total instructions rematerialized. + int64 instructions_rematerialized_ = 0; + + // Count of the net instructions added to the HLO module by + // rematerialization. This can be different than instructions_rematerialized_ + // because some rematerializations are effectively moves in the HLO + // schedule. In these cases, the rematerialization instruction replaces all + // uses of the original instruction and the original instruction is + // dead. Hence, no net instructions were added. + int64 net_instructions_added_ = 0; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc index 0a4f277689..82de1c835b 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -27,15 +28,17 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +namespace op = xla::testing::opcode_matchers; + namespace xla { namespace { -class HloOrderingTest : public HloTestBase { +class HloRematerializationTest : public HloTestBase { protected: // Creates and returns a computation which can benefit from // rematerialization. The computation looks like: // - // F32[1] %param = {...} + // F32[] %param = {...} // F32[1024] %bcast = broadcast(%param) // F32[1024] %negate = negate(%bcast) // F32[2048] %concat_1 = concat({%negate, %negate}) @@ -52,7 +55,7 @@ class HloOrderingTest : public HloTestBase { const string& suffix = "") { auto builder = HloComputation::Builder(TestName() + suffix); auto param = builder.AddInstruction( - HloInstruction::CreateParameter(0, vec1_shape_, "param")); + HloInstruction::CreateParameter(0, scalar_shape_, "param")); auto bcast = builder.AddInstruction( HloInstruction::CreateBroadcast(vec1024_shape_, param, {})); auto negate = builder.AddInstruction( @@ -77,7 +80,7 @@ class HloOrderingTest : public HloTestBase { // Creates and returns a computation which includes a while and can benefit // from rematerialization. The computation looks like: // - // F32[1] %param = {...} + // F32[] %param = {...} // F32[1024] %bcast = broadcast(%param) // F32[1] %slice_1 = slice(%bcast, {0:1}) // F32[1] %while = while(%slice_1, while_body, while_cond) @@ -93,7 +96,7 @@ class HloOrderingTest : public HloTestBase { const string& suffix = "") { auto builder = HloComputation::Builder(TestName() + suffix); auto param = builder.AddInstruction( - HloInstruction::CreateParameter(0, vec1_shape_, "param")); + HloInstruction::CreateParameter(0, scalar_shape_, "param")); auto bcast = builder.AddInstruction( HloInstruction::CreateBroadcast(vec1024_shape_, param, {})); auto slice_1 = builder.AddInstruction( @@ -127,13 +130,14 @@ class HloOrderingTest : public HloTestBase { } // Various shapes used in the canned computations. + const Shape scalar_shape_ = ShapeUtil::MakeShape(xla::F32, {}); const Shape vec1_shape_ = ShapeUtil::MakeShape(xla::F32, {1}); const Shape vec1024_shape_ = ShapeUtil::MakeShape(xla::F32, {1024}); }; // Test rematerialization of a single computation produced by // MakeRematerializableComputation. -TEST_F(HloOrderingTest, SingleComputation) { +TEST_F(HloRematerializationTest, SingleComputation) { HloModule module(TestName()); HloComputation* computation = module.AddEntryComputation(MakeRematerializableComputation()); @@ -175,7 +179,7 @@ TEST_F(HloOrderingTest, SingleComputation) { // Test rematerialization of a single computation produced by // MakeRematerializableComputation but with a sufficiently high memory limit // such that no instructions are rematerialized. -TEST_F(HloOrderingTest, SingleComputationNoRematerialization) { +TEST_F(HloRematerializationTest, SingleComputationNoRematerialization) { HloModule module(TestName()); HloComputation* computation = module.AddEntryComputation(MakeRematerializableComputation()); @@ -199,7 +203,7 @@ TEST_F(HloOrderingTest, SingleComputationNoRematerialization) { // only one computation needs to have an instruction rematerialized. The entry // computation should be the one chosen because rematerialization in the while // will presumably be more expensive. -TEST_F(HloOrderingTest, RematerializeAroundWhile) { +TEST_F(HloRematerializationTest, RematerializeAroundWhile) { HloModule module(TestName()); auto cond_builder = HloComputation::Builder(TestName() + ".cond"); @@ -237,7 +241,7 @@ TEST_F(HloOrderingTest, RematerializeAroundWhile) { // Test rematerialization of a computation which calls another computation via a // while. Both the entry computation and while body computation should have // computations rematerialized. -TEST_F(HloOrderingTest, RematerializeEntryAndWhileBody) { +TEST_F(HloRematerializationTest, RematerializeEntryAndWhileBody) { HloModule module(TestName()); auto cond_builder = HloComputation::Builder(TestName() + ".cond"); @@ -271,7 +275,7 @@ TEST_F(HloOrderingTest, RematerializeEntryAndWhileBody) { // Test rematerialization of a doubly nested computation. All computations // should have an instruction rematerialized. -TEST_F(HloOrderingTest, RematerializeNestedComputations) { +TEST_F(HloRematerializationTest, RematerializeNestedComputations) { HloModule module(TestName()); auto cond_builder = HloComputation::Builder(TestName() + ".cond"); @@ -311,6 +315,203 @@ TEST_F(HloOrderingTest, RematerializeNestedComputations) { EXPECT_EQ(inner_computation->instruction_count(), 8); } +TEST_F(HloRematerializationTest, InstructionRematerializedMultipleTimes) { + // Test that a single instruction is rematerialized several times. Module: + // + // Entry computation: + // F32[] %param = {...} + // F32[1024] %bcast = broadcast(%param) + // F32[1024] %add_1 = add(%bcast, bcast) + // F32[1024] %call_1 = call(Subcomputation, {%add_1}) + // F32[1024] %add_2 = add(%bcast, call_1) + // F32[1024] %call_2 = call(SubComputation, {%add_2}) + // F32[1024] %add_3 = add(%bcast, call_2) + // F32[1024] %call_3 = call(Subcomputation, {%add_3}) + // F32[1024] %add_4 = add(%bcast, call_3) + // + // Subcomputation: + // F32[1024] %param = {...} + // F32[2048] %concat = concat({%param, %param}) + // F32[1024] %slice = slice(%concat) + // + // The value %bcast is live across each call of Subcomputation (which requires + // 8KB) though the value is not used in the calls. Rematerializing %bcast + // across these calls reduces peak memory use from ~20KB down to ~16KB. + HloModule module(TestName()); + + HloComputation* subcomputation = nullptr; + { + auto builder = HloComputation::Builder(TestName() + ".subcomputation"); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, vec1024_shape_, "param")); + auto concat = builder.AddInstruction(HloInstruction::CreateConcatenate( + ShapeUtil::MakeShape(xla::F32, {2048}), {param, param}, + /*dimension=*/0)); + builder.AddInstruction(HloInstruction::CreateSlice( + vec1024_shape_, concat, /*start_indices=*/{0}, + /*limit_indices=*/{1024})); + subcomputation = module.AddEmbeddedComputation(builder.Build()); + } + + auto builder = HloComputation::Builder(TestName()); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape_, "param")); + auto bcast = builder.AddInstruction( + HloInstruction::CreateBroadcast(vec1024_shape_, param, {})); + auto add_1 = builder.AddInstruction(HloInstruction::CreateBinary( + vec1024_shape_, HloOpcode::kAdd, bcast, bcast)); + auto call_1 = builder.AddInstruction( + HloInstruction::CreateCall(vec1024_shape_, {add_1}, subcomputation)); + auto add_2 = builder.AddInstruction(HloInstruction::CreateBinary( + vec1024_shape_, HloOpcode::kAdd, bcast, call_1)); + auto call_2 = builder.AddInstruction( + HloInstruction::CreateCall(vec1024_shape_, {add_2}, subcomputation)); + auto add_3 = builder.AddInstruction(HloInstruction::CreateBinary( + vec1024_shape_, HloOpcode::kAdd, bcast, call_2)); + auto call_3 = builder.AddInstruction( + HloInstruction::CreateCall(vec1024_shape_, {add_3}, subcomputation)); + auto add_4 = builder.AddInstruction(HloInstruction::CreateBinary( + vec1024_shape_, HloOpcode::kAdd, bcast, call_3)); + HloComputation* entry_computation = + module.AddEntryComputation(builder.Build()); + + auto count_broadcasts = [](const HloComputation* computation) { + int64 bcast_count = 0; + for (auto& instruction : computation->instructions()) { + if (instruction->opcode() == HloOpcode::kBroadcast) { + bcast_count++; + } + } + return bcast_count; + }; + + // Before rematerialization there should be a single broadcast instruction in + // the graph. + EXPECT_EQ(count_broadcasts(entry_computation), 1); + EXPECT_EQ(entry_computation->instruction_count(), 9); + + EXPECT_EQ(add_2->operand(0), bcast); + EXPECT_EQ(add_3->operand(0), bcast); + EXPECT_EQ(add_4->operand(0), bcast); + + SequentialHloOrdering::HloModuleSequence sequence; + // Pick a memory limit some where between 24KB (initial peak memory including + // parameter and output) and 20KB (peak memory possible with + // rematerialization). + TF_ASSIGN_OR_ASSERT_OK( + bool changed, HloRematerialization::RematerializeAndSchedule( + ByteSizeOf, + /*memory_limit_bytes=*/22 * 1024, &module, &sequence)); + EXPECT_TRUE(changed); + + // The broadcast should have been rematerialized 3 times. + EXPECT_EQ(count_broadcasts(entry_computation), 4); + EXPECT_EQ(entry_computation->instruction_count(), 12); + + // The operands of add_2, add_3, and add_4 should all be rematerialized + // broadcasts. + EXPECT_NE(add_2->operand(0), bcast); + EXPECT_THAT(add_2->operand(0), op::Broadcast(param)); + EXPECT_NE(add_3->operand(0), bcast); + EXPECT_THAT(add_3->operand(0), op::Broadcast(param)); + EXPECT_NE(add_4->operand(0), bcast); + EXPECT_THAT(add_4->operand(0), op::Broadcast(param)); +} + +class IndirectUseTest : public HloRematerializationTest, + public ::testing::WithParamInterface {}; + +TEST_P(IndirectUseTest, IndirectUseNotRematerialized) { + // Test that an rematerializable instruction is not rematerialized if it has + // an indirect use. Test is parameterized on whether the value has an indirect + // use, and the instruction should be rematerialized iff the value has no + // indirect use. Module: + // + // Entry computation: + // F32[] %param = {...} + // F32[1024] %bcast = broadcast(%param) + // F32[1024] %add_1 = add(%bcast, bcast) + // F32[1024] %call = call(Subcomputation, {%add_1}) + // F32[1024] %add_2 = add(%bcast, call) + // {F32[1024], F32[1024]} %tuple = tuple(%bcast, %add_2) + // F32[1024] %gte = GetTupleElememt(%tuple, 0) + // F32[1024] %negate = negate(%gte) + // + // Subcomputation: + // F32[1024] %param = {...} + // F32[2048] %concat = concat({%param, %param}) + // F32[1024] %slice = slice(%concat) + // + // The value %bcast is live across the call and rematerialization of %bcast + // across that point would reduce peak memory use by 4KB. However, %bcast is + // used indirectly in the %negate so rematerialization should not happen. + // + // This test is parameterized on whether the broadcast has an indirect use or + // not. The indirect use is controlled by the index of the GetTupleElement + // instruction. If the element is 0, then the %negate operand aliases %bcast + // (ie %bcast is used indirectly by %negate), otherwise the %negate operand + // aliases %add_2. + const bool indirectly_used = GetParam(); + HloModule module(TestName()); + + HloComputation* subcomputation = nullptr; + { + auto builder = HloComputation::Builder(TestName() + ".subcomputation"); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, vec1024_shape_, "param")); + auto concat = builder.AddInstruction(HloInstruction::CreateConcatenate( + ShapeUtil::MakeShape(xla::F32, {2048}), {param, param}, + /*dimension=*/0)); + builder.AddInstruction(HloInstruction::CreateSlice( + vec1024_shape_, concat, /*start_indices=*/{0}, + /*limit_indices=*/{1024})); + subcomputation = module.AddEmbeddedComputation(builder.Build()); + } + + auto builder = HloComputation::Builder(TestName()); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape_, "param")); + auto bcast = builder.AddInstruction( + HloInstruction::CreateBroadcast(vec1024_shape_, param, {})); + auto add_1 = builder.AddInstruction(HloInstruction::CreateBinary( + vec1024_shape_, HloOpcode::kAdd, bcast, bcast)); + auto call_1 = builder.AddInstruction( + HloInstruction::CreateCall(vec1024_shape_, {add_1}, subcomputation)); + auto add_2 = builder.AddInstruction(HloInstruction::CreateBinary( + vec1024_shape_, HloOpcode::kAdd, bcast, call_1)); + auto tuple = + builder.AddInstruction(HloInstruction::CreateTuple({bcast, add_2})); + auto gte = builder.AddInstruction(HloInstruction::CreateGetTupleElement( + vec1024_shape_, tuple, indirectly_used ? 0 : 1)); + builder.AddInstruction( + HloInstruction::CreateUnary(vec1024_shape_, HloOpcode::kNegate, gte)); + HloComputation* entry_computation = + module.AddEntryComputation(builder.Build()); + + EXPECT_EQ(entry_computation->instruction_count(), 8); + + SequentialHloOrdering::HloModuleSequence sequence; + // Pick a memory limit some where between 24KB (initial peak memory including + // parameter and output) and 20KB (peak memory possible with + // rematerialization). + TF_ASSIGN_OR_ASSERT_OK( + bool changed, HloRematerialization::RematerializeAndSchedule( + ByteSizeOf, + /*memory_limit_bytes=*/22 * 1024, &module, &sequence)); + // Rematerialization should only occur if the rematerializable instruction has + // no indirect uses. + if (indirectly_used) { + EXPECT_FALSE(changed); + EXPECT_EQ(entry_computation->instruction_count(), 8); + } else { + EXPECT_TRUE(changed); + EXPECT_EQ(entry_computation->instruction_count(), 9); + } +} + +INSTANTIATE_TEST_CASE_P(IndirectUseTestInstantiation, IndirectUseTest, + ::testing::Values(true, false)); + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/liveness_util.cc b/tensorflow/compiler/xla/service/liveness_util.cc index 6c5f185ed1..e0991fcb76 100644 --- a/tensorflow/compiler/xla/service/liveness_util.cc +++ b/tensorflow/compiler/xla/service/liveness_util.cc @@ -28,8 +28,9 @@ limitations under the License. namespace xla { -bool DoesNotUseOperandBuffer(HloInstruction* operand, const ShapeIndex& index, - HloInstruction* user, +bool DoesNotUseOperandBuffer(const HloInstruction* operand, + const ShapeIndex& index, + const HloInstruction* user, const TuplePointsToAnalysis& points_to_analysis) { CHECK(user->IsUserOf(operand)) << "user: " << user->ToString() << " operand: " << operand->ToString(); diff --git a/tensorflow/compiler/xla/service/liveness_util.h b/tensorflow/compiler/xla/service/liveness_util.h index 410a7b1b51..52de282ca6 100644 --- a/tensorflow/compiler/xla/service/liveness_util.h +++ b/tensorflow/compiler/xla/service/liveness_util.h @@ -32,8 +32,9 @@ namespace xla { // 'operand'. Returns false otherwise. // // REQUIRES: 'operand' is an operand of 'user'. -bool DoesNotUseOperandBuffer(HloInstruction* operand, const ShapeIndex& index, - HloInstruction* user, +bool DoesNotUseOperandBuffer(const HloInstruction* operand, + const ShapeIndex& index, + const HloInstruction* user, const TuplePointsToAnalysis& points_to_analysis); // Returns true if 'user' (at 'user_index') can share a buffer with its operand -- GitLab From 7888d8c318e6b0d54d3bfdb44dde47643256a728 Mon Sep 17 00:00:00 2001 From: Yao Zhang Date: Wed, 26 Apr 2017 14:29:09 -0800 Subject: [PATCH 023/697] Disable the requirement for a shape inference function. Change: 154353632 --- tensorflow/core/grappler/costs/graph_properties.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow/core/grappler/costs/graph_properties.cc b/tensorflow/core/grappler/costs/graph_properties.cc index 4edcdccdfe..06e91af2c2 100644 --- a/tensorflow/core/grappler/costs/graph_properties.cc +++ b/tensorflow/core/grappler/costs/graph_properties.cc @@ -26,6 +26,7 @@ namespace grappler { Status GraphProperties::InferStatically() { Graph graph(OpRegistry::Global()); ShapeRefiner shape_refiner(graph.versions().producer(), graph.op_registry()); + shape_refiner.set_require_shape_inference_fns(false); ImportGraphDefOptions options; Status s = ImportGraphDef(options, item_.graph, &graph, &shape_refiner); TF_RETURN_IF_ERROR(s); -- GitLab From b6c105d23a12b844784a7de37abbadf2bb183ff1 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 26 Apr 2017 14:41:33 -0800 Subject: [PATCH 024/697] Check if memory_sequence_length is not None before converting tensor Change: 154355223 --- tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py index 04b38159bb..2e2b2ebe60 100644 --- a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py +++ b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py @@ -73,8 +73,9 @@ def _prepare_memory(memory, memory_sequence_length, check_inner_dims_defined): """ memory = nest.map_structure( lambda m: ops.convert_to_tensor(m, name="memory"), memory) - memory_sequence_length = ops.convert_to_tensor( - memory_sequence_length, name="memory_sequence_length") + if memory_sequence_length is not None: + memory_sequence_length = ops.convert_to_tensor( + memory_sequence_length, name="memory_sequence_length") if check_inner_dims_defined: def _check_dims(m): if not m.get_shape()[2:].is_fully_defined(): -- GitLab From 626f5657686cea341ba2200ed0a50479c0a8f3fb Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 26 Apr 2017 15:39:11 -0800 Subject: [PATCH 025/697] [XLA] LiteralUtil::EachCellAsString : Pass in callback function as reference. Change: 154362604 --- tensorflow/compiler/xla/literal_util.cc | 5 ++--- tensorflow/compiler/xla/literal_util.h | 5 ++--- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc index 7091c324d1..0286b0817c 100644 --- a/tensorflow/compiler/xla/literal_util.cc +++ b/tensorflow/compiler/xla/literal_util.cc @@ -680,9 +680,8 @@ void TransposeLiteralInternal(const Literal& original, /* static */ void LiteralUtil::EachCellAsString( const Literal& literal, - std::function indices, - const string& value)> - per_cell) { + const std::function indices, + const string& value)>& per_cell) { if (ShapeUtil::Rank(literal.shape()) == 1) { for (int64 i0 = 0; i0 < literal.shape().dimensions(0); ++i0) { per_cell({i0}, GetAsString(literal, {i0})); diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h index 21bb2e46cf..ef78b819e3 100644 --- a/tensorflow/compiler/xla/literal_util.h +++ b/tensorflow/compiler/xla/literal_util.h @@ -257,9 +257,8 @@ class LiteralUtil { // like representation in a protobuf). static void EachCellAsString( const Literal& literal, - std::function indices, - const string& value)> - per_cell); + const std::function indices, + const string& value)>& per_cell); template static void EachCell( const Literal& literal, -- GitLab From 21990e2a7c2f563202e95e4e5d78c8b599957f33 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 26 Apr 2017 15:39:55 -0800 Subject: [PATCH 026/697] Time out after 30 minutes when waiting for the session to be ready. Change: 154362697 --- tensorflow/python/estimator/estimator_test.py | 2 ++ tensorflow/python/training/monitored_session.py | 4 +++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py index 3b46db59e3..f70c285f04 100644 --- a/tensorflow/python/estimator/estimator_test.py +++ b/tensorflow/python/estimator/estimator_test.py @@ -540,6 +540,8 @@ class EstimatorTrainTest(test.TestCase): # Mocking the SessionManager.wait_for_session, so that worker doesn't wait # for chief. def get_initialized_session(*args, **kwargs): + # Session doesn't take 'max_wait_secs' argument. + kwargs.pop('max_wait_secs', None) scaffold = training.Scaffold().finalize() sess = session.Session(*args, **kwargs) sess.run(scaffold.init_op) diff --git a/tensorflow/python/training/monitored_session.py b/tensorflow/python/training/monitored_session.py index 6d6128d207..4c81af56ad 100644 --- a/tensorflow/python/training/monitored_session.py +++ b/tensorflow/python/training/monitored_session.py @@ -422,7 +422,9 @@ class WorkerSessionCreator(SessionCreator): def create_session(self): self._scaffold.finalize() return self._get_session_manager().wait_for_session( - self._master, config=self._config) + self._master, config=self._config, + max_wait_secs=30 * 60 # Wait up to 30 mins for the session to be ready. + ) class _MonitoredSession(object): -- GitLab From 7314ba8f5d8419892b23a5c89fd809c8d86fdcb8 Mon Sep 17 00:00:00 2001 From: Skye Wanderman-Milne Date: Wed, 26 Apr 2017 16:00:30 -0800 Subject: [PATCH 027/697] Test that Graphs and Operations are gc'd despite cyclic references. I ran into this working on a patch where I added a __del__ method to Operation, which caused garbage collection to be disabled for cyclic references. This triggered a failure in PyOpTest.testCleanup but was very hard to debug. This test should more clearly guard against making this mistake. Change: 154365215 --- tensorflow/python/framework/ops_test.py | 30 +++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py index 2cff66dfb7..06d03121a0 100644 --- a/tensorflow/python/framework/ops_test.py +++ b/tensorflow/python/framework/ops_test.py @@ -18,7 +18,11 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import gc +import weakref + from tensorflow.core.framework import attr_value_pb2 +from tensorflow.python.client import session from tensorflow.python.framework import common_shapes from tensorflow.python.framework import constant_op from tensorflow.python.framework import device as pydev @@ -1298,6 +1302,32 @@ class GraphTest(test_util.TensorFlowTestCase): with self.assertRaises(TypeError): g.as_graph_element(NonConvertibleObj()) + # Regression test against creating custom __del__ functions in classes + # involved in cyclic references, e.g. Graph and Operation. (Python won't gc + # cycles that require calling a __del__ method, because the __del__ method can + # theoretically increase the object's refcount to "save" it from gc, and any + # already-deleted objects in the cycle would have be to restored.) + def testGarbageCollected(self): + # Create a graph we can delete and a weak reference to monitor if it's gc'd + g = ops.Graph() + g_ref = weakref.ref(g) + # Create some ops + with g.as_default(): + a = constant_op.constant(2.0) + b = constant_op.constant(3.0) + c = math_ops.add(a, b) + # Create a session we can delete + with session.Session(graph=g) as sess: + sess.run(c) + # Delete all references and trigger gc + del g + del a + del b + del c + del sess + gc.collect() + self.assertIsNone(g_ref()) + class AttrScopeTest(test_util.TensorFlowTestCase): -- GitLab From 1e3e5d424eaa6332314f8ad1d54089eb0f9e02e7 Mon Sep 17 00:00:00 2001 From: Francois Chollet Date: Wed, 26 Apr 2017 16:12:34 -0800 Subject: [PATCH 028/697] Refactor Keras layers to rely on core TF layers. API change: for users of custom Keras layers built using `tf.contrib.keras`, the method `add_weight` of the Keras base layer has now a new API (synced with the main Keras GitHub repo). Change: 154366685 --- tensorflow/contrib/keras/BUILD | 1 + .../contrib/keras/python/keras/backend.py | 40 +- .../keras/python/keras/engine/topology.py | 348 +++++------------ .../python/keras/engine/topology_test.py | 4 +- .../keras/layers/advanced_activations.py | 2 +- .../python/keras/layers/convolutional.py | 14 +- .../keras/layers/convolutional_recurrent.py | 6 +- .../contrib/keras/python/keras/layers/core.py | 64 +-- .../keras/python/keras/layers/core_test.py | 13 +- .../keras/python/keras/layers/embeddings.py | 2 +- .../keras/python/keras/layers/local.py | 8 +- .../keras/python/keras/layers/merge.py | 1 + .../python/keras/layers/normalization.py | 8 +- .../keras/python/keras/layers/recurrent.py | 20 +- .../contrib/keras/python/keras/models.py | 9 + .../seq2seq/python/ops/attention_wrapper.py | 4 +- .../seq2seq/python/ops/basic_decoder.py | 2 +- .../seq2seq/python/ops/beam_search_decoder.py | 2 +- .../contrib/seq2seq/python/ops/helper.py | 2 +- tensorflow/python/layers/base.py | 363 ++++++++++++------ tensorflow/python/layers/base_test.py | 84 ++-- tensorflow/python/layers/convolutional.py | 70 ++-- tensorflow/python/layers/core.py | 28 +- tensorflow/python/layers/core_test.py | 12 +- tensorflow/python/layers/normalization.py | 51 ++- tensorflow/python/layers/pooling.py | 6 +- tensorflow/python/ops/rnn_cell_impl.py | 2 +- 27 files changed, 568 insertions(+), 598 deletions(-) diff --git a/tensorflow/contrib/keras/BUILD b/tensorflow/contrib/keras/BUILD index 5166ba37a3..b1b8fc49b6 100644 --- a/tensorflow/contrib/keras/BUILD +++ b/tensorflow/contrib/keras/BUILD @@ -119,6 +119,7 @@ py_library( "//tensorflow/python:gradients", "//tensorflow/python:image_ops", "//tensorflow/python:init_ops", + "//tensorflow/python:layers", "//tensorflow/python:logging_ops", "//tensorflow/python:math_ops", "//tensorflow/python:nn", diff --git a/tensorflow/contrib/keras/python/keras/backend.py b/tensorflow/contrib/keras/python/keras/backend.py index e52b23843a..905ef13e14 100644 --- a/tensorflow/contrib/keras/python/keras/backend.py +++ b/tensorflow/contrib/keras/python/keras/backend.py @@ -21,7 +21,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from collections import defaultdict import json import os import warnings @@ -245,17 +244,40 @@ def set_image_data_format(data_format): def get_uid(prefix=''): - global _GRAPH_UID_DICTS # pylint: disable=global-variable-not-assigned - graph = ops.get_default_graph() - if graph not in _GRAPH_UID_DICTS: - _GRAPH_UID_DICTS[graph] = defaultdict(int) - _GRAPH_UID_DICTS[graph][prefix] += 1 - return _GRAPH_UID_DICTS[graph][prefix] + """Associates a string prefix with an integer counter in a TensorFlow graph. + + Arguments: + prefix: String prefix to index. + + Returns: + Unique integer ID. + + Example: + + ``` + >>> get_uid('dense') + 1 + >>> get_uid('dense') + 2 + ``` + """ + layer_name_uids_collection = ops.get_collection('LAYER_NAME_UIDS') + if not layer_name_uids_collection: + layer_name_uids = {} + ops.add_to_collection('LAYER_NAME_UIDS', layer_name_uids) + else: + layer_name_uids = layer_name_uids_collection[0] + if prefix not in layer_name_uids: + layer_name_uids[prefix] = 1 + else: + layer_name_uids[prefix] += 1 + return layer_name_uids[prefix] def reset_uids(): - global _GRAPH_UID_DICTS - _GRAPH_UID_DICTS = {} + layer_name_uids_collection = ops.get_collection_ref('LAYER_NAME_UIDS') + if layer_name_uids_collection: + layer_name_uids_collection.pop() def clear_session(): diff --git a/tensorflow/contrib/keras/python/keras/engine/topology.py b/tensorflow/contrib/keras/python/keras/engine/topology.py index 7848e5982d..0336fc4bf4 100644 --- a/tensorflow/contrib/keras/python/keras/engine/topology.py +++ b/tensorflow/contrib/keras/python/keras/engine/topology.py @@ -29,11 +29,12 @@ import numpy as np from six.moves import zip # pylint: disable=redefined-builtin from tensorflow.contrib.keras.python.keras import backend as K -from tensorflow.contrib.keras.python.keras import initializers from tensorflow.contrib.keras.python.keras.utils import conv_utils from tensorflow.contrib.keras.python.keras.utils.io_utils import ask_to_proceed_with_overwrite from tensorflow.contrib.keras.python.keras.utils.layer_utils import print_summary as print_layer_summary +from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape +from tensorflow.python.layers import base as tf_base_layers from tensorflow.python.util import tf_inspect @@ -207,7 +208,7 @@ class Node(object): } -class Layer(object): +class Layer(tf_base_layers.Layer): """Abstract base layer class. # Properties @@ -276,24 +277,6 @@ class Layer(object): """ def __init__(self, **kwargs): - self.input_spec = None - self.supports_masking = False - - # These properties will be set upon call of self.build() - self._trainable_weights = [] - self._non_trainable_weights = [] - self._constraints = {} # dict {tensor: constraint instance} - self._losses = [] - self._updates = [] - self._per_input_losses = {} - self._per_input_updates = {} - self._built = False - - # These lists will be filled via successive calls - # to self._add_inbound_node(). - self.inbound_nodes = [] - self.outbound_nodes = [] - # These properties should be set by the user via keyword arguments. # note that 'dtype', 'input_shape' and 'batch_input_shape' # are only applicable to input layers: do not pass these keywords @@ -306,18 +289,38 @@ class Layer(object): 'name', 'trainable', 'weights', - 'input_dtype', # legacy } + # Validate optional keyword arguments. for kwarg in kwargs: if kwarg not in allowed_kwargs: raise TypeError('Keyword argument not understood:', kwarg) + + # Get layer name. name = kwargs.get('name') - if not name: - prefix = self.__class__.__name__ - name = _to_snake_case(prefix) + '_' + str(K.get_uid(prefix)) - self.name = name - self.trainable = kwargs.get('trainable', True) + # Get `trainable` status. + trainable = kwargs.get('trainable', True) + + # Get `dtype`. + dtype = kwargs.get('dtype') + if dtype is None: + dtype = K.floatx() + + # Call super, which will set all properties common to Keras layers + # and core TF layers. + super(Layer, self).__init__(name=name, dtype=dtype, trainable=trainable) + + # Add properties that are Keras-only for now. + self.input_spec = None + self.supports_masking = False + self._constraints = {} # dict {tensor: constraint instance} + + # These lists will be filled via successive calls + # to self._add_inbound_node(). + self.inbound_nodes = [] + self.outbound_nodes = [] + + # Manage input shape information if passed. if 'input_shape' in kwargs or 'batch_input_shape' in kwargs: # In this case we will later create an input layer # to insert before the current layer @@ -331,35 +334,12 @@ class Layer(object): batch_input_shape = (batch_size,) + tuple(kwargs['input_shape']) self.batch_input_shape = batch_input_shape - # Set dtype. - dtype = kwargs.get('dtype') - if dtype is None: - dtype = kwargs.get('input_dtype') - if dtype is None: - dtype = K.floatx() - self.dtype = dtype - + # Manage initial weight values if passed. if 'weights' in kwargs: self._initial_weights = kwargs['weights'] else: self._initial_weights = None - @property - def losses(self): - return self._losses - - @property - def updates(self): - return self._updates - - @property - def built(self): - return self._built - - @built.setter - def built(self, value): - self._built = value - @property def constraints(self): return self._constraints @@ -368,63 +348,37 @@ class Layer(object): def constraints(self, constraints): self._constraints = constraints - @property - def trainable_weights(self): - trainable = getattr(self, 'trainable', True) - if trainable: - return self._trainable_weights - else: - return [] - - @trainable_weights.setter - def trainable_weights(self, weights): - self._trainable_weights = weights - - @property - def non_trainable_weights(self): - trainable = getattr(self, 'trainable', True) - if not trainable: - return self._trainable_weights + self._non_trainable_weights - else: - return self._non_trainable_weights - - @non_trainable_weights.setter - def non_trainable_weights(self, weights): - self._non_trainable_weights = weights - def add_weight(self, + name, shape, - initializer, - name=None, - trainable=True, + dtype=None, + initializer=None, regularizer=None, + trainable=True, constraint=None): """Adds a weight variable to the layer. Arguments: + name: String, the name for the weight variable. shape: The shape tuple of the weight. + dtype: The dtype of the weight. initializer: An Initializer instance (callable). - name: String, the name for the weight variable. + regularizer: An optional Regularizer instance. trainable: A boolean, whether the weight should be trained via backprop or not (assuming that the layer itself is also trainable). - regularizer: An optional Regularizer instance. constraint: An optional Constraint instance. Returns: The created weight variable. """ - shape = tuple(tensor_shape.TensorShape(shape).as_list()) - initializer = initializers.get(initializer) - weight = K.variable(initializer(shape), dtype=K.floatx(), name=name) - if regularizer is not None: - self.add_loss(regularizer(weight)) + if dtype is None: + dtype = K.floatx() + weight = self.add_variable( + name, shape, dtype=dtype, + initializer=initializer, regularizer=regularizer, trainable=trainable) if constraint is not None: self.constraints[weight] = constraint - if trainable: - self._trainable_weights.append(weight) - else: - self._non_trainable_weights.append(weight) return weight def assert_input_compatibility(self, inputs): @@ -554,66 +508,46 @@ class Layer(object): """ if isinstance(inputs, list): inputs = inputs[:] + + # Raise exceptions in case the input is not compatible + # with the input_spec set at build time. + # TODO(fchollet): call after the layer is built, too. + self.assert_input_compatibility(inputs) + + # Handle mask propagation. + previous_mask = _collect_previous_mask(inputs) + user_kwargs = copy.copy(kwargs) + if not _is_all_none(previous_mask): + # The previous layer generated a mask. + if 'mask' in tf_inspect.getargspec(self.call).args: + if 'mask' not in kwargs: + # If mask is explicitly passed to __call__, + # we should override the default mask. + kwargs['mask'] = previous_mask + + # Actually call the layer (optionally building it). + output = super(Layer, self).__call__(inputs, **kwargs) + + # Handle mask computation. with K.name_scope(self.name): - # Handle laying building (weight creating, input spec locking). - if not self.built: - # Raise exceptions in case the input is not compatible - # with the input_spec specified in the layer constructor. - self.assert_input_compatibility(inputs) - - # Collect input shapes to build layer. - input_shapes = [] - for x_elem in _to_list(inputs): - input_shapes.append(K.int_shape(x_elem)) - if len(input_shapes) == 1: - self.build(input_shapes[0]) - else: - self.build(input_shapes) - self.built = True - - # Load weights that were specified at layer instantiation. - if self._initial_weights is not None: - self.set_weights(self._initial_weights) - - # Raise exceptions in case the input is not compatible - # with the input_spec set at build time. - self.assert_input_compatibility(inputs) - - # Handle mask propagation. - previous_mask = _collect_previous_mask(inputs) - user_kwargs = copy.copy(kwargs) - if not _is_all_none(previous_mask): - # The previous layer generated a mask. - if 'mask' in tf_inspect.getargspec(self.call).args: - if 'mask' not in kwargs: - # If mask is explicitly passed to __call__, - # we should override the default mask. - kwargs['mask'] = previous_mask - - # Actually call the layer, collecting output(s), mask(s), and shape(s). - output = self.call(inputs, **kwargs) output_mask = self.compute_mask(inputs, previous_mask) - # Add an inbound node to the layer, so that it keeps track - # of the call and of all new variables created during the call. - # This also updates the layer history of the output tensor(s). - # If the input tensor(s) had not previous Keras history, - # this does nothing. - self._add_inbound_node( - input_tensors=inputs, - output_tensors=output, - input_masks=previous_mask, - output_masks=output_mask, - arguments=user_kwargs) - - # Apply activity regularizer if any: - if hasattr( - self, - 'activity_regularizer') and self.activity_regularizer is not None: - regularization_losses = [ - self.activity_regularizer(x) for x in _to_list(output) - ] - self.add_loss(regularization_losses, _to_list(inputs)) + # Add an inbound node to the layer, so that it keeps track + # of the call and of all new variables created during the call. + # This also updates the layer history of the output tensor(s). + # If the input tensor(s) had not previous Keras history, + # this does nothing. + self._add_inbound_node( + input_tensors=inputs, + output_tensors=output, + input_masks=previous_mask, + output_masks=output_mask, + arguments=user_kwargs) + + # Optionally load weight values that were specified at layer instantiation. + if hasattr(self, '_initial_weights') and self._initial_weights is not None: + self.set_weights(self._initial_weights) + del self._initial_weights return output def _add_inbound_node(self, @@ -959,14 +893,14 @@ class Layer(object): @property def input_shape(self): - """Retrieves the input shape tuple(s) of a layer. + """Retrieves the input shape(s) of a layer. Only applicable if the layer has exactly one inbound node, i.e. if it is connected to one incoming layer. Returns: - Input shape tuple - (or list of input shape tuples, one tuple per input tensor). + Input shape, as `TensorShape` + (or list of `TensorShape`, one tuple per input tensor). Raises: AttributeError: if the layer is connected to @@ -997,14 +931,14 @@ class Layer(object): @property def output_shape(self): - """Retrieves the output shape tuple(s) of a layer. + """Retrieves the output shape(s) of a layer. Only applicable if the layer has one inbound node, or if all inbound nodes have the same output shape. Returns: - Output shape tuple - (or list of input shape tuples, one tuple per output tensor). + Output shape, as `TensorShape` + (or list of `TensorShape`, one tuple per output tensor). Raises: AttributeError: if the layer is connected to @@ -1033,94 +967,6 @@ class Layer(object): 'Use `get_output_shape_at(node_index)` ' 'instead.') - def add_loss(self, losses, inputs=None): - """Add losses to the layer. - - The loss may potentially be conditional on some inputs tensors, - for instance activity losses are conditional on the layer's inputs. - - Arguments: - losses: loss tensor or list of loss tensors - to add to the layer. - inputs: input tensor or list of inputs tensors to mark - the losses as conditional on these inputs. - If None is passed, the loss is assumed unconditional - (e.g. L2 weight regularization, which only depends - on the layer's weights variables, not on any inputs tensors). - """ - if losses is None or losses == []: # pylint: disable=g-explicit-bool-comparison - return - # Update self.losses - losses = _to_list(losses) - if hasattr(self, '_losses'): - self._losses += losses - # Update self._per_input_updates - if inputs == []: # pylint: disable=g-explicit-bool-comparison - inputs = None - if inputs is not None: - inputs_hash = _object_list_uid(inputs) - else: - # Updates indexed by None are unconditional - # rather than input-dependent - inputs_hash = None - if inputs_hash not in self._per_input_losses: - self._per_input_losses[inputs_hash] = [] - self._per_input_losses[inputs_hash] += losses - - def add_update(self, updates, inputs=None): - """Add updates to the layer. - - The updates may potentially be conditional on some inputs tensors, - for instance batch norm updates are conditional on the layer's inputs. - - Arguments: - updates: update op or list of update ops - to add to the layer. - inputs: input tensor or list of inputs tensors to mark - the updates as conditional on these inputs. - If None is passed, the updates are assumed unconditional. - """ - if updates is None or updates == []: # pylint: disable=g-explicit-bool-comparison - return - # Update self.updates - updates = _to_list(updates) - if hasattr(self, '_updates'): - self._updates += updates - # Update self._per_input_updates - if inputs == []: # pylint: disable=g-explicit-bool-comparison - inputs = None - if inputs is not None: - inputs_hash = _object_list_uid(inputs) - else: - # Updates indexed by None are unconditional - # rather than input-dependent - inputs_hash = None - if inputs_hash not in self._per_input_updates: - self._per_input_updates[inputs_hash] = [] - self._per_input_updates[inputs_hash] += updates - - def get_updates_for(self, inputs): - if inputs is not None: - inputs_hash = _object_list_uid(inputs) - else: - inputs_hash = None - if inputs_hash in self._per_input_updates: - return self._per_input_updates[inputs_hash] - return [] - - def get_losses_for(self, inputs): - if inputs is not None: - inputs_hash = _object_list_uid(inputs) - else: - inputs_hash = None - if inputs_hash in self._per_input_losses: - return self._per_input_losses[inputs_hash] - return [] - - @property - def weights(self): - return self.trainable_weights + self.non_trainable_weights - def set_weights(self, weights): """Sets the weights of the layer, from Numpy arrays. @@ -1254,9 +1100,12 @@ class InputLayer(Layer): if not name: prefix = 'input' name = prefix + '_' + str(K.get_uid(prefix)) + if not dtype: + if input_tensor is None: + dtype = K.floatx() + else: + dtype = K.dtype(input_tensor) super(InputLayer, self).__init__(dtype=dtype, name=name) - - self.trainable = False self.built = True self.sparse = sparse @@ -1284,15 +1133,7 @@ class InputLayer(Layer): batch_input_shape = (batch_size,) + tuple(input_shape) else: batch_input_shape = tuple(batch_input_shape) - - if not dtype: - if input_tensor is None: - dtype = K.floatx() - else: - dtype = K.dtype(input_tensor) - self.batch_input_shape = batch_input_shape - self.dtype = dtype if input_tensor is None: self.is_placeholder = True @@ -1446,12 +1287,19 @@ class Container(Layer): prefix = self.__class__.__name__.lower() name = prefix + '_' + str(K.get_uid(prefix)) self.name = name - self.supports_masking = False self.trainable = True self._per_input_losses = {} self._per_input_updates = {} + # The following properties are not actually used by Keras; + # they exist for compatibility with TF. + self._updates = [] + self._scope = None + self._reuse = None + self._base_name = name + self._graph = ops.get_default_graph() + # Container-specific properties. if isinstance(inputs, (list, tuple)): self.inputs = list(inputs) # Tensor or list of tensors. diff --git a/tensorflow/contrib/keras/python/keras/engine/topology_test.py b/tensorflow/contrib/keras/python/keras/engine/topology_test.py index eb095b14a9..531ed4be3e 100644 --- a/tensorflow/contrib/keras/python/keras/engine/topology_test.py +++ b/tensorflow/contrib/keras/python/keras/engine/topology_test.py @@ -490,8 +490,8 @@ class TopologyConstructionTest(test.TestCase): m, n = model([j, k]) tf_model = keras.models.Model([j, k], [m, n]) - j_tf = array_ops.placeholder(dtype=dtypes.float32) - k_tf = array_ops.placeholder(dtype=dtypes.float32) + j_tf = array_ops.placeholder(dtype=dtypes.float32, shape=(None, 32)) + k_tf = array_ops.placeholder(dtype=dtypes.float32, shape=(None, 32)) m_tf, n_tf = tf_model([j_tf, k_tf]) self.assertListEqual(m_tf.get_shape().as_list(), [None, 64]) self.assertListEqual(n_tf.get_shape().as_list(), [None, 5]) diff --git a/tensorflow/contrib/keras/python/keras/layers/advanced_activations.py b/tensorflow/contrib/keras/python/keras/layers/advanced_activations.py index b3abfc29d2..2c957ece44 100644 --- a/tensorflow/contrib/keras/python/keras/layers/advanced_activations.py +++ b/tensorflow/contrib/keras/python/keras/layers/advanced_activations.py @@ -120,7 +120,7 @@ class PReLU(Layer): param_shape[i - 1] = 1 self.param_broadcast[i - 1] = True self.alpha = self.add_weight( - param_shape, + shape=param_shape, name='alpha', initializer=self.alpha_initializer, regularizer=self.alpha_regularizer, diff --git a/tensorflow/contrib/keras/python/keras/layers/convolutional.py b/tensorflow/contrib/keras/python/keras/layers/convolutional.py index 38b8fe66a3..16f49c3390 100644 --- a/tensorflow/contrib/keras/python/keras/layers/convolutional.py +++ b/tensorflow/contrib/keras/python/keras/layers/convolutional.py @@ -140,14 +140,14 @@ class _Conv(Layer): kernel_shape = self.kernel_size + (input_dim, self.filters) self.kernel = self.add_weight( - kernel_shape, + shape=kernel_shape, initializer=self.kernel_initializer, name='kernel', regularizer=self.kernel_regularizer, constraint=self.kernel_constraint) if self.use_bias: self.bias = self.add_weight( - (self.filters,), + shape=(self.filters,), initializer=self.bias_initializer, name='bias', regularizer=self.bias_regularizer, @@ -734,14 +734,14 @@ class Conv2DTranspose(Conv2D): kernel_shape = self.kernel_size + (self.filters, input_dim) self.kernel = self.add_weight( - kernel_shape, + shape=kernel_shape, initializer=self.kernel_initializer, name='kernel', regularizer=self.kernel_regularizer, constraint=self.kernel_constraint) if self.use_bias: self.bias = self.add_weight( - (self.filters,), + shape=(self.filters,), initializer=self.bias_initializer, name='bias', regularizer=self.bias_regularizer, @@ -949,13 +949,13 @@ class SeparableConv2D(Conv2D): self.filters) self.depthwise_kernel = self.add_weight( - depthwise_kernel_shape, + shape=depthwise_kernel_shape, initializer=self.depthwise_initializer, name='depthwise_kernel', regularizer=self.depthwise_regularizer, constraint=self.depthwise_constraint) self.pointwise_kernel = self.add_weight( - pointwise_kernel_shape, + shape=pointwise_kernel_shape, initializer=self.pointwise_initializer, name='pointwise_kernel', regularizer=self.pointwise_regularizer, @@ -963,7 +963,7 @@ class SeparableConv2D(Conv2D): if self.use_bias: self.bias = self.add_weight( - (self.filters,), + shape=(self.filters,), initializer=self.bias_initializer, name='bias', regularizer=self.bias_regularizer, diff --git a/tensorflow/contrib/keras/python/keras/layers/convolutional_recurrent.py b/tensorflow/contrib/keras/python/keras/layers/convolutional_recurrent.py index 4d8ef44da7..30325b7148 100644 --- a/tensorflow/contrib/keras/python/keras/layers/convolutional_recurrent.py +++ b/tensorflow/contrib/keras/python/keras/layers/convolutional_recurrent.py @@ -369,20 +369,20 @@ class ConvLSTM2D(ConvRecurrent2D): recurrent_kernel_shape = self.kernel_size + (self.filters, self.filters * 4) self.kernel = self.add_weight( - kernel_shape, + shape=kernel_shape, initializer=self.kernel_initializer, name='kernel', regularizer=self.kernel_regularizer, constraint=self.kernel_constraint) self.recurrent_kernel = self.add_weight( - recurrent_kernel_shape, + shape=recurrent_kernel_shape, initializer=self.recurrent_initializer, name='recurrent_kernel', regularizer=self.recurrent_regularizer, constraint=self.recurrent_constraint) if self.use_bias: self.bias = self.add_weight( - (self.filters * 4,), + shape=(self.filters * 4,), initializer=self.bias_initializer, name='bias', regularizer=self.bias_regularizer, diff --git a/tensorflow/contrib/keras/python/keras/layers/core.py b/tensorflow/contrib/keras/python/keras/layers/core.py index 32ada176a4..7a9e0d1736 100644 --- a/tensorflow/contrib/keras/python/keras/layers/core.py +++ b/tensorflow/contrib/keras/python/keras/layers/core.py @@ -34,6 +34,7 @@ from tensorflow.contrib.keras.python.keras.utils.generic_utils import deserializ from tensorflow.contrib.keras.python.keras.utils.generic_utils import func_dump from tensorflow.contrib.keras.python.keras.utils.generic_utils import func_load from tensorflow.python.framework import tensor_shape +from tensorflow.python.layers import core as tf_core_layers from tensorflow.python.util import tf_inspect @@ -643,7 +644,7 @@ class Lambda(Layer): return cls(**config) -class Dense(Layer): +class Dense(tf_core_layers.Dense, Layer): """Just your regular densely-connected NN layer. `Dense` implements the operation: @@ -712,15 +713,20 @@ class Dense(Layer): **kwargs): if 'input_shape' not in kwargs and 'input_dim' in kwargs: kwargs['input_shape'] = (kwargs.pop('input_dim'),) - super(Dense, self).__init__(**kwargs) - self.units = units - self.activation = activations.get(activation) - self.use_bias = use_bias - self.kernel_initializer = initializers.get(kernel_initializer) - self.bias_initializer = initializers.get(bias_initializer) - self.kernel_regularizer = regularizers.get(kernel_regularizer) - self.bias_regularizer = regularizers.get(bias_regularizer) - self.activity_regularizer = regularizers.get(activity_regularizer) + + # Inheritance call order: + # 1) tf.layers.Dense, 2) keras.layers.Layer, 3) tf.layers.Layer + super(Dense, self).__init__( + units, + activation=activations.get(activation), + use_bias=use_bias, + kernel_initializer=initializers.get(kernel_initializer), + bias_initializer=initializers.get(bias_initializer), + kernel_regularizer=regularizers.get(kernel_regularizer), + bias_regularizer=regularizers.get(bias_regularizer), + activity_regularizer=regularizers.get(activity_regularizer), + **kwargs) + self.kernel_constraint = constraints.get(kernel_constraint) self.bias_constraint = constraints.get(bias_constraint) self.input_spec = InputSpec(min_ndim=2) @@ -729,40 +735,12 @@ class Dense(Layer): def build(self, input_shape): assert len(input_shape) >= 2 input_dim = input_shape[-1] - - self.kernel = self.add_weight( - (input_dim, self.units), - initializer=self.kernel_initializer, - name='kernel', - regularizer=self.kernel_regularizer, - constraint=self.kernel_constraint) - if self.use_bias: - self.bias = self.add_weight( - (self.units,), - initializer=self.bias_initializer, - name='bias', - regularizer=self.bias_regularizer, - constraint=self.bias_constraint) - else: - self.bias = None + super(Dense, self).build(input_shape) self.input_spec = InputSpec(min_ndim=2, axes={-1: input_dim}) - self.built = True - - def call(self, inputs): - output = K.dot(inputs, self.kernel) - if self.use_bias: - output = K.bias_add(output, self.bias) - if self.activation is not None: - output = self.activation(output) - return output - - def _compute_output_shape(self, input_shape): - input_shape = tensor_shape.TensorShape(input_shape).as_list() - assert input_shape and len(input_shape) >= 2 - assert input_shape[-1] - output_shape = list(input_shape) - output_shape[-1] = self.units - return tensor_shape.TensorShape(output_shape) + if self.kernel_constraint: + self.constraints[self.kernel] = self.kernel_constraint + if self.use_bias and self.bias_constraint: + self.constraints[self.bias] = self.bias_constraint def get_config(self): config = { diff --git a/tensorflow/contrib/keras/python/keras/layers/core_test.py b/tensorflow/contrib/keras/python/keras/layers/core_test.py index d7aa8413bb..7066af0ef6 100644 --- a/tensorflow/contrib/keras/python/keras/layers/core_test.py +++ b/tensorflow/contrib/keras/python/keras/layers/core_test.py @@ -165,24 +165,23 @@ class CoreLayersTest(test.TestCase): 3, kernel_regularizer=keras.regularizers.l1(0.01), bias_regularizer='l1', - activity_regularizer='l2') - layer.build((None, 4)) - assert len(layer.losses) == 2 + activity_regularizer='l2', + name='dense_reg') layer(keras.backend.variable(np.ones((2, 4)))) - assert len(layer.losses) == 3 + self.assertEqual(3, len(layer.losses)) # Test constraints with self.test_session(): layer = keras.layers.Dense( 3, kernel_constraint='max_norm', bias_constraint='max_norm') - layer.build((None, 4)) - assert len(layer.constraints) == 2 + layer(keras.backend.variable(np.ones((2, 4)))) + self.assertEqual(2, len(layer.constraints)) def test_activity_regularization(self): with self.test_session(): layer = keras.layers.ActivityRegularization(l1=0.1) layer(keras.backend.variable(np.ones((2, 4)))) - assert len(layer.losses) == 1 + self.assertEqual(1, len(layer.losses)) if __name__ == '__main__': diff --git a/tensorflow/contrib/keras/python/keras/layers/embeddings.py b/tensorflow/contrib/keras/python/keras/layers/embeddings.py index 12a2ce39eb..bc0bae67d0 100644 --- a/tensorflow/contrib/keras/python/keras/layers/embeddings.py +++ b/tensorflow/contrib/keras/python/keras/layers/embeddings.py @@ -116,7 +116,7 @@ class Embedding(Layer): def build(self, input_shape): input_shape = tensor_shape.TensorShape(input_shape).as_list() self.embeddings = self.add_weight( - (self.input_dim, self.output_dim), + shape=(self.input_dim, self.output_dim), initializer=self.embeddings_initializer, name='embeddings', regularizer=self.embeddings_regularizer, diff --git a/tensorflow/contrib/keras/python/keras/layers/local.py b/tensorflow/contrib/keras/python/keras/layers/local.py index d96ccc4a63..863674c1cb 100644 --- a/tensorflow/contrib/keras/python/keras/layers/local.py +++ b/tensorflow/contrib/keras/python/keras/layers/local.py @@ -130,14 +130,14 @@ class LocallyConnected1D(Layer): self.kernel_shape = (output_length, self.kernel_size[0] * input_dim, self.filters) self.kernel = self.add_weight( - self.kernel_shape, + shape=self.kernel_shape, initializer=self.kernel_initializer, name='kernel', regularizer=self.kernel_regularizer, constraint=self.kernel_constraint) if self.use_bias: self.bias = self.add_weight( - (output_length, self.filters), + shape=(output_length, self.filters), initializer=self.bias_initializer, name='bias', regularizer=self.bias_regularizer, @@ -340,14 +340,14 @@ class LocallyConnected2D(Layer): output_row * output_col, self.kernel_size[0] * self.kernel_size[1] * input_filter, self.filters) self.kernel = self.add_weight( - self.kernel_shape, + shape=self.kernel_shape, initializer=self.kernel_initializer, name='kernel', regularizer=self.kernel_regularizer, constraint=self.kernel_constraint) if self.use_bias: self.bias = self.add_weight( - (output_row, output_col, self.filters), + shape=(output_row, output_col, self.filters), initializer=self.bias_initializer, name='bias', regularizer=self.bias_regularizer, diff --git a/tensorflow/contrib/keras/python/keras/layers/merge.py b/tensorflow/contrib/keras/python/keras/layers/merge.py index 7c6482d0de..25921979bd 100644 --- a/tensorflow/contrib/keras/python/keras/layers/merge.py +++ b/tensorflow/contrib/keras/python/keras/layers/merge.py @@ -87,6 +87,7 @@ class _Merge(Layer): raise ValueError('A merge layer should be called ' 'on a list of at least 2 inputs. ' 'Got ' + str(len(input_shape)) + ' inputs.') + input_shape = [tensor_shape.TensorShape(s).as_list() for s in input_shape] batch_sizes = [s[0] for s in input_shape if s is not None] batch_sizes = set(batch_sizes) batch_sizes -= set([None]) diff --git a/tensorflow/contrib/keras/python/keras/layers/normalization.py b/tensorflow/contrib/keras/python/keras/layers/normalization.py index 9a0340aeaf..df77401aee 100644 --- a/tensorflow/contrib/keras/python/keras/layers/normalization.py +++ b/tensorflow/contrib/keras/python/keras/layers/normalization.py @@ -116,7 +116,7 @@ class BatchNormalization(Layer): if self.scale: self.gamma = self.add_weight( - shape, + shape=shape, name='gamma', initializer=self.gamma_initializer, regularizer=self.gamma_regularizer, @@ -125,7 +125,7 @@ class BatchNormalization(Layer): self.gamma = None if self.center: self.beta = self.add_weight( - shape, + shape=shape, name='beta', initializer=self.beta_initializer, regularizer=self.beta_regularizer, @@ -133,12 +133,12 @@ class BatchNormalization(Layer): else: self.beta = None self.moving_mean = self.add_weight( - shape, + shape=shape, name='moving_mean', initializer=self.moving_mean_initializer, trainable=False) self.moving_variance = self.add_weight( - shape, + shape=shape, name='moving_variance', initializer=self.moving_variance_initializer, trainable=False) diff --git a/tensorflow/contrib/keras/python/keras/layers/recurrent.py b/tensorflow/contrib/keras/python/keras/layers/recurrent.py index 1ea1cb22d9..e608921add 100644 --- a/tensorflow/contrib/keras/python/keras/layers/recurrent.py +++ b/tensorflow/contrib/keras/python/keras/layers/recurrent.py @@ -493,20 +493,20 @@ class SimpleRNN(Recurrent): self.reset_states() self.kernel = self.add_weight( - (self.input_dim, self.units), + shape=(self.input_dim, self.units), name='kernel', initializer=self.kernel_initializer, regularizer=self.kernel_regularizer, constraint=self.kernel_constraint) self.recurrent_kernel = self.add_weight( - (self.units, self.units), + shape=(self.units, self.units), name='recurrent_kernel', initializer=self.recurrent_initializer, regularizer=self.recurrent_regularizer, constraint=self.recurrent_constraint) if self.use_bias: self.bias = self.add_weight( - (self.units,), + shape=(self.units,), name='bias', initializer=self.bias_initializer, regularizer=self.bias_regularizer, @@ -723,13 +723,13 @@ class GRU(Recurrent): self.reset_states() self.kernel = self.add_weight( - (self.input_dim, self.units * 3), + shape=(self.input_dim, self.units * 3), name='kernel', initializer=self.kernel_initializer, regularizer=self.kernel_regularizer, constraint=self.kernel_constraint) self.recurrent_kernel = self.add_weight( - (self.units, self.units * 3), + shape=(self.units, self.units * 3), name='recurrent_kernel', initializer=self.recurrent_initializer, regularizer=self.recurrent_regularizer, @@ -737,9 +737,9 @@ class GRU(Recurrent): if self.use_bias: self.bias = self.add_weight( - (self.units * 3,), + shape=(self.units * 3,), name='bias', - initializer='zero', + initializer=self.bias_initializer, regularizer=self.bias_regularizer, constraint=self.bias_constraint) else: @@ -1039,13 +1039,13 @@ class LSTM(Recurrent): self.reset_states() self.kernel = self.add_weight( - (self.input_dim, self.units * 4), + shape=(self.input_dim, self.units * 4), name='kernel', initializer=self.kernel_initializer, regularizer=self.kernel_regularizer, constraint=self.kernel_constraint) self.recurrent_kernel = self.add_weight( - (self.units, self.units * 4), + shape=(self.units, self.units * 4), name='recurrent_kernel', initializer=self.recurrent_initializer, regularizer=self.recurrent_regularizer, @@ -1053,7 +1053,7 @@ class LSTM(Recurrent): if self.use_bias: self.bias = self.add_weight( - (self.units * 4,), + shape=(self.units * 4,), name='bias', initializer=self.bias_initializer, regularizer=self.bias_regularizer, diff --git a/tensorflow/contrib/keras/python/keras/models.py b/tensorflow/contrib/keras/python/keras/models.py index eb0996fa12..52456a4bb5 100644 --- a/tensorflow/contrib/keras/python/keras/models.py +++ b/tensorflow/contrib/keras/python/keras/models.py @@ -35,6 +35,7 @@ from tensorflow.contrib.keras.python.keras.engine.topology import Input from tensorflow.contrib.keras.python.keras.engine.topology import Layer from tensorflow.contrib.keras.python.keras.engine.training import Model from tensorflow.contrib.keras.python.keras.utils.io_utils import ask_to_proceed_with_overwrite +from tensorflow.python.framework import ops # pylint: disable=g-import-not-at-top @@ -420,6 +421,14 @@ class Sequential(Model): name = prefix + str(K.get_uid(prefix)) self.name = name + # The following properties are not actually used by Keras; + # they exist for compatibility with TF's variable scoping mechanism. + self._updates = [] + self._scope = None + self._reuse = None + self._base_name = name + self._graph = ops.get_default_graph() + # Add to the model any layers passed to the constructor. if layers: for layer in layers: diff --git a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py index 2e2b2ebe60..6c181fe832 100644 --- a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py +++ b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py @@ -144,11 +144,11 @@ class _BaseAttentionMechanism(AttentionMechanism): name: Name to use when creating ops. """ if (query_layer is not None - and not isinstance(query_layer, layers_base._Layer)): # pylint: disable=protected-access + and not isinstance(query_layer, layers_base.Layer)): raise TypeError( "query_layer is not a Layer: %s" % type(query_layer).__name__) if (memory_layer is not None - and not isinstance(memory_layer, layers_base._Layer)): # pylint: disable=protected-access + and not isinstance(memory_layer, layers_base.Layer)): raise TypeError( "memory_layer is not a Layer: %s" % type(memory_layer).__name__) self._query_layer = query_layer diff --git a/tensorflow/contrib/seq2seq/python/ops/basic_decoder.py b/tensorflow/contrib/seq2seq/python/ops/basic_decoder.py index e73a637027..6231a1fdf9 100644 --- a/tensorflow/contrib/seq2seq/python/ops/basic_decoder.py +++ b/tensorflow/contrib/seq2seq/python/ops/basic_decoder.py @@ -65,7 +65,7 @@ class BasicDecoder(decoder.Decoder): if not isinstance(helper, helper_py.Helper): raise TypeError("helper must be a Helper, received: %s" % type(helper)) if (output_layer is not None - and not isinstance(output_layer, layers_base._Layer)): # pylint: disable=protected-access + and not isinstance(output_layer, layers_base.Layer)): raise TypeError( "output_layer must be a Layer, received: %s" % type(output_layer)) self._cell = cell diff --git a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py index 289da8e6ae..2dbaf746ed 100644 --- a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py +++ b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py @@ -146,7 +146,7 @@ class BeamSearchDecoder(decoder.Decoder): if not isinstance(cell, core_rnn_cell.RNNCell): raise TypeError("cell must be an RNNCell, received: %s" % type(cell)) if (output_layer is not None - and not isinstance(output_layer, layers_base._Layer)): # pylint: disable=protected-access + and not isinstance(output_layer, layers_base.Layer)): raise TypeError( "output_layer must be a Layer, received: %s" % type(output_layer)) self._cell = cell diff --git a/tensorflow/contrib/seq2seq/python/ops/helper.py b/tensorflow/contrib/seq2seq/python/ops/helper.py index e2d56063a2..d6c0527ad2 100644 --- a/tensorflow/contrib/seq2seq/python/ops/helper.py +++ b/tensorflow/contrib/seq2seq/python/ops/helper.py @@ -363,7 +363,7 @@ class ScheduledOutputTrainingHelper(TrainingHelper): self._seed = seed if (next_input_layer is not None and not isinstance(next_input_layer, - layers_base._Layer)): # pylint: disable=protected-access + layers_base.Layer)): raise TypeError("next_input_layer must be a Layer, received: %s" % type(next_input_layer)) self._next_input_layer = next_input_layer diff --git a/tensorflow/python/layers/base.py b/tensorflow/python/layers/base.py index ff9a777f19..f6b816333e 100644 --- a/tensorflow/python/layers/base.py +++ b/tensorflow/python/layers/base.py @@ -38,7 +38,7 @@ from tensorflow.python.util import nest from tensorflow.python.util import tf_inspect -class _Layer(object): +class Layer(object): """Base layer class. WARNING: Do not subclass this layer unless you know what you are doing: @@ -80,23 +80,27 @@ class _Layer(object): if kwarg not in allowed_kwargs: raise TypeError('Keyword argument not understood:', kwarg) - self._trainable = trainable - self._built = False - self._trainable_variables = [] - self._non_trainable_variables = [] + self.trainable = trainable + self.built = False + self._trainable_weights = [] + self._non_trainable_weights = [] self._updates = [] self._losses = [] self._reuse = kwargs.get('_reuse') self._graph = ops.get_default_graph() - self.dtype = dtype + self._per_input_losses = {} + self._per_input_updates = {} + self.dtype = dtypes.as_dtype(dtype).name - # Determine base name (non-unique). + # Determine layer name (non-unique). if isinstance(name, vs.VariableScope): base_name = name.name else: base_name = name + self.name = name if not name: base_name = _to_snake_case(self.__class__.__name__) + self.name = _unique_layer_name(base_name) self._base_name = base_name # Determine variable scope. @@ -106,45 +110,43 @@ class _Layer(object): else: self._scope = None - # Unique name is borrowed from scope to match variable names. - if self._scope is not None: - self._name = self._scope.name - else: - # No name available until we see a scope - self._name = None - - def __setattr__(self, name, value): - if hasattr(self, name): - # Only allow private attributes to be set more than once, under the - # convention that private attributes should only be set from inside - # the class. - # All attributes meant to be set several times should be set to private. - if name[0] != '_': - raise AttributeError('Read-only property cannot be set: %s' % name) - super(_Layer, self).__setattr__(name, value) + @property + def scope_name(self): + if not self._scope: + raise ValueError('No name available for layer scope because the layer "' + + self.name + '" has not been used yet. The scope name ' + + ' is determined the first time the layer instance is ' + + 'called. You must therefore call the layer before ' + + 'querying `scope_name`.') + return self._scope.name + + @property + def trainable_weights(self): + return self._trainable_weights if self.trainable else [] @property - def name(self): - if self._name is None: - raise ValueError( - 'No name available for layer because it has not been used yet.') - return self._name + def non_trainable_weights(self): + if self.trainable: + return self._non_trainable_weights + else: + return self._trainable_weights + self._non_trainable_weights @property def trainable_variables(self): - return self._trainable_variables if self.trainable else [] + return self.trainable_weights @property def non_trainable_variables(self): - return self._non_trainable_variables if self.trainable else self.variables + return self.non_trainable_weights @property - def trainable_weights(self): - return self.trainable_variables + def weights(self): + """Returns the list of all layer variables/weights. - @property - def non_trainable_weights(self): - return self.non_trainable_variables + Returns: + A list of variables. + """ + return self.trainable_weights + self.non_trainable_weights @property def variables(self): @@ -153,37 +155,141 @@ class _Layer(object): Returns: A list of variables. """ - return self._trainable_variables + self._non_trainable_variables + return self.weights @property def updates(self): return self._updates + def add_update(self, updates, inputs=None): + """Add update op(s), potentially dependent on layer inputs. + + Weight updates (for instance, the updates of the moving mean and variance + in a BatchNormalization layer) may be dependent on the inputs passed + when calling a layer. Hence, when reusing a same layer on + different inputs `a` and `b`, some entries in `layer.updates` may be + dependent on `a` and some on `b`. This method automatically keeps track + of dependencies. + + The `get_updates_for` method allows to retrieve the updates relevant to a + specific set of inputs. + + Arguments: + updates: Update op, or list/tuple of update ops. + inputs: Optional input tensor(s) that the update(s) depend on. Must + match the `inputs` argument passed to the `__call__` method at the time + the updates are created. If `None` is passed, the updates are assumed + to be unconditional, and will apply across all dataflows of the layer. + """ + updates = _to_list(updates) + if not updates: + return + self._updates += updates + if inputs is not None: + inputs = _to_list(inputs) + if not inputs: + inputs = None + if inputs is not None: + # We compute an ID that uniquely identifies the list of tensors. + # This ID is order-sensitive. + inputs_hash = _object_list_uid(inputs) + else: + inputs_hash = None + if inputs_hash not in self._per_input_updates: + self._per_input_updates[inputs_hash] = [] + self._per_input_updates[inputs_hash] += updates + + def get_updates_for(self, inputs): + """Retrieves updates relevant to a specific set of inputs. + + Arguments: + inputs: Input tensor or list/tuple of input tensors. + Must match the `inputs` argument passed to the `__call__` method + at the time the updates were created. + If you pass `inputs=None`, unconditional updates are returned. + + Returns: + List of update ops of the layer that depend on `inputs`. + """ + if inputs is not None: + inputs = _to_list(inputs) + if not inputs: + inputs = None + if inputs is not None: + inputs_hash = _object_list_uid(inputs) + else: + inputs_hash = None + return self._per_input_updates.get(inputs_hash, []) + @property def losses(self): return self._losses - @property - def built(self): - return self._built + def add_loss(self, losses, inputs=None): + """Add loss tensor(s), potentially dependent on layer inputs. - @property - def trainable(self): - return self._trainable + Some losses (for instance, activity regularization losses) may be dependent + on the inputs passed when calling a layer. Hence, when reusing a same layer + on different inputs `a` and `b`, some entries in `layer.losses` may be + dependent on `a` and some on `b`. This method automatically keeps track + of dependencies. - @property - def weights(self): - """Returns the list of all layer variables/weights. + The `get_losses_for` method allows to retrieve the losses relevant to a + specific set of inputs. + + Arguments: + losses: Loss tensor, or list/tuple of tensors. + inputs: Optional input tensor(s) that the loss(es) depend on. Must + match the `inputs` argument passed to the `__call__` method at the time + the losses are created. If `None` is passed, the losses are assumed + to be unconditional, and will apply across all dataflows of the layer + (e.g. weight regularization losses). + """ + losses = _to_list(losses) + if not losses: + return + self._losses += losses + if inputs is not None: + inputs = _to_list(inputs) + if not inputs: + inputs = None + if inputs is not None: + # We compute an ID that uniquely identifies the list of tensors. + # This ID is order-sensitive. + inputs_hash = _object_list_uid(inputs) + else: + inputs_hash = None + if inputs_hash not in self._per_input_losses: + self._per_input_losses[inputs_hash] = [] + self._per_input_losses[inputs_hash] += losses + + def get_losses_for(self, inputs): + """Retrieves losses relevant to a specific set of inputs. + + Arguments: + inputs: Input tensor or list/tuple of input tensors. + Must match the `inputs` argument passed to the `__call__` + method at the time the losses were created. + If you pass `inputs=None`, unconditional losses are returned, + such as weight regularization losses. Returns: - A list of variables. + List of loss tensors of the layer that depend on `inputs`. """ - return self.variables + if inputs is not None: + inputs = _to_list(inputs) + if not inputs: + inputs = None + if inputs is not None: + inputs_hash = _object_list_uid(inputs) + else: + inputs_hash = None + return self._per_input_losses.get(inputs_hash, []) def build(self, _): """Creates the variables of the layer. """ - self._built = True + self.built = True def call(self, inputs, **kwargs): """The logic of the layer lives here. @@ -217,9 +323,18 @@ class _Layer(object): """ raise NotImplementedError - def _add_variable(self, name, shape, dtype=None, - initializer=None, regularizer=None, trainable=True, - variable_getter=vs.get_variable): + def _set_scope(self, scope=None): + if self._scope is None: + # If constructed with _scope=None, lazy setting of scope. + if self._reuse: + self._scope = next(vs.variable_scope( + scope if scope is not None else self._base_name).gen) + else: + self._scope = next(vs.variable_scope( + scope, default_name=self._base_name).gen) + + def add_variable(self, name, shape, dtype=None, + initializer=None, regularizer=None, trainable=True): """Adds a new variable to the layer. Arguments: @@ -231,7 +346,6 @@ class _Layer(object): trainable: whether the variable should be part of the layer's "trainable_variables" (e.g. variables, biases) or "non_trainable_variables" (e.g. BatchNorm mean, stddev). - variable_getter: The getter to use for TensorFlow variables. Returns: The created variable. @@ -239,38 +353,43 @@ class _Layer(object): if dtype is None: dtype = self.dtype existing_variables = set(tf_variables.global_variables()) - variable = variable_getter(name, - shape=shape, - initializer=initializer, - dtype=dtype, - trainable=trainable and self.trainable) - # TODO(sguada) fix name = variable.op.name - if variable in existing_variables: - return variable - if regularizer: - # To match the behavior of tf.get_variable(), we only - # apply regularization if the variable is newly created. - if isinstance(variable, tf_variables.PartitionedVariable): - for v in variable: - with ops.colocate_with(v.op): - with ops.name_scope(name + '/Regularizer'): - regularization = regularizer(v) - if regularization is not None: - self._losses.append(regularization) - _add_elements_to_collection( - regularization, ops.GraphKeys.REGULARIZATION_LOSSES) - else: - with ops.colocate_with(variable.op): - with ops.name_scope(name + '/Regularizer'): - regularization = regularizer(variable) - if regularization is not None: - self._losses.append(regularization) - _add_elements_to_collection( - regularization, ops.GraphKeys.REGULARIZATION_LOSSES) + + self._set_scope(None) + + with vs.variable_scope(self._scope, + reuse=self.built or self._reuse) as scope: + with ops.name_scope(scope.original_name_scope): + variable = vs.get_variable(name, + shape=shape, + initializer=initializer, + dtype=dtypes.as_dtype(dtype), + trainable=trainable and self.trainable) + if variable in existing_variables: + return variable + if regularizer: + # To match the behavior of tf.get_variable(), we only + # apply regularization if the variable is newly created. + if isinstance(variable, tf_variables.PartitionedVariable): + for v in variable: + with ops.colocate_with(v.op): + with ops.name_scope(name + '/Regularizer'): + regularization = regularizer(v) + if regularization is not None: + self.add_loss(regularization) + _add_elements_to_collection( + regularization, ops.GraphKeys.REGULARIZATION_LOSSES) + else: + with ops.colocate_with(variable.op): + with ops.name_scope(name + '/Regularizer'): + regularization = regularizer(variable) + if regularization is not None: + self.add_loss(regularization) + _add_elements_to_collection( + regularization, ops.GraphKeys.REGULARIZATION_LOSSES) if trainable: - self._trainable_variables.append(variable) + self._trainable_weights.append(variable) else: - self._non_trainable_variables.append(variable) + self._non_trainable_weights.append(variable) return variable def __call__(self, inputs, *args, **kwargs): @@ -284,39 +403,17 @@ class _Layer(object): Returns: Output tensor(s). """ - scope = kwargs.pop('scope', None) - - # Define a custom getter to override tf.get_variable when creating layer - # variables. The current custom getter is nested by the variable scope. - def variable_getter(getter, name, shape, dtype=None, initializer=None, - regularizer=None, trainable=True, **getter_kwargs): - return self._add_variable( - name, shape, initializer=initializer, regularizer=regularizer, - dtype=dtype, trainable=trainable, - variable_getter=functools.partial(getter, **getter_kwargs)) - - if not self._built and self._scope is None: - # If constructed with _scope=None, lazy setting of scope. - if self._reuse: - self._scope = next(vs.variable_scope( - scope if scope is not None else self._base_name).gen) - else: - self._scope = next(vs.variable_scope( - scope, default_name=self._base_name).gen) - self._name = self._scope.name + self._set_scope(kwargs.pop('scope', None)) - # Build (if necessary) and call the layer, inside a variable - # scope. - with vs.variable_scope(self._scope, - reuse=True if self._built else self._reuse, - custom_getter=variable_getter) as scope: - # Ensure the Layer, if being reused, is working with inputs from - # the same graph as where it was created. - try: - ops._get_graph_from_inputs(nest.flatten(inputs), graph=self.graph) # pylint: disable=protected-access - except ValueError as e: - raise ValueError("Inputs' and Layer's graphs are not the same: %s" % e) + # Ensure the Layer, if being reused, is working with inputs from + # the same graph as where it was created. + try: + ops._get_graph_from_inputs(nest.flatten(inputs), graph=self.graph) # pylint: disable=protected-access + except ValueError as e: + raise ValueError('Input graph and Layer graph are not the same: %s' % e) + with vs.variable_scope(self._scope, + reuse=self.built or self._reuse) as scope: with ops.name_scope(scope.original_name_scope): if not self.built: input_list = [ @@ -327,7 +424,7 @@ class _Layer(object): self.build(input_shapes[0]) else: self.build(input_shapes) - self._built = True + self.built = True if 'scope' in tf_inspect.getargspec(self.call).args: kwargs['scope'] = scope outputs = self.call(inputs, *args, **kwargs) @@ -340,7 +437,7 @@ class _Layer(object): for output in output_list: with ops.name_scope('ActivityRegularizer'): activity_regularization = self.activity_regularizer(output) - self._losses.append(activity_regularization) + self.add_loss(activity_regularization) _add_elements_to_collection( activity_regularization, ops.GraphKeys.REGULARIZATION_LOSSES) @@ -419,3 +516,39 @@ def _add_elements_to_collection(elements, collections): for element in elements: if element not in collection_set: collection.append(element) + + +def _object_list_uid(object_list): + object_list = _to_list(object_list) + return ', '.join([str(abs(id(x))) for x in object_list]) + + +def _unique_layer_name(name): + """Makes a layer name (or arbitrary string) unique within a TensorFlow graph. + + Arguments: + name: String name to make unique. + + Returns: + Unique string name. + + Example: + + ``` + >>> _unique_layer_name('dense') + dense_1 + >>> _unique_layer_name('dense') + dense_2 + ``` + """ + layer_name_uids_collection = ops.get_collection('LAYER_NAME_UIDS') + if not layer_name_uids_collection: + layer_name_uids = {} + ops.add_to_collection('LAYER_NAME_UIDS', layer_name_uids) + else: + layer_name_uids = layer_name_uids_collection[0] + if name not in layer_name_uids: + layer_name_uids[name] = 1 + else: + layer_name_uids[name] += 1 + return name + '_' + str(layer_name_uids[name]) diff --git a/tensorflow/python/layers/base_test.py b/tensorflow/python/layers/base_test.py index 83ae1b6e83..9acf1c05e2 100644 --- a/tensorflow/python/layers/base_test.py +++ b/tensorflow/python/layers/base_test.py @@ -32,26 +32,24 @@ from tensorflow.python.platform import test class BaseLayerTest(test.TestCase): def testLayerProperties(self): - layer = base_layers._Layer(name='my_layer') + layer = base_layers.Layer(name='my_layer') self.assertListEqual(layer.variables, []) self.assertListEqual(layer.trainable_variables, []) self.assertListEqual(layer.non_trainable_variables, []) self.assertListEqual(layer.updates, []) self.assertListEqual(layer.losses, []) self.assertEqual(layer.built, False) - with self.assertRaisesRegexp(ValueError, 'not been used yet'): - _ = layer.name - layer = base_layers._Layer(name='my_layer', trainable=False) + layer = base_layers.Layer(name='my_layer', trainable=False) self.assertEqual(layer.trainable, False) def testAddWeight(self): with self.test_session(): - layer = base_layers._Layer(name='my_layer') + layer = base_layers.Layer(name='my_layer') # Test basic variable creation. - variable = layer._add_variable( + variable = layer.add_variable( 'my_var', [2, 2], initializer=init_ops.zeros_initializer()) - self.assertEqual(variable.name, 'my_var:0') + self.assertEqual(variable.name, 'my_layer/my_var:0') self.assertListEqual(layer.variables, [variable]) self.assertListEqual(layer.trainable_variables, [variable]) self.assertListEqual(layer.non_trainable_variables, []) @@ -60,8 +58,8 @@ class BaseLayerTest(test.TestCase): ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)) # Test non-trainable variable creation. - # layer._add_variable should work even outside `build` and `call`. - variable_2 = layer._add_variable( + # layer.add_variable should work even outside `build` and `call`. + variable_2 = layer.add_variable( 'non_trainable_var', [2, 2], initializer=init_ops.zeros_initializer(), trainable=False) @@ -73,7 +71,7 @@ class BaseLayerTest(test.TestCase): # Test with regularizer. regularizer = lambda x: math_ops.reduce_sum(x) * 1e-3 - variable = layer._add_variable( + variable = layer.add_variable( 'reg_var', [2, 2], initializer=init_ops.zeros_initializer(), regularizer=regularizer) @@ -81,81 +79,70 @@ class BaseLayerTest(test.TestCase): def testGetVariable(self): with self.test_session(): - # From inside `build` and `call` it should be possible to use - # either tf.get_variable - class MyLayer(base_layers._Layer): + class MyLayer(base_layers.Layer): def build(self, input_shape): - self.my_var = variable_scope.get_variable( + self.my_var = self.add_variable( 'my_var', [2, 2], initializer=init_ops.zeros_initializer()) def call(self, inputs): - variable_scope.get_variable( - 'my_call_var', [2, 2], initializer=init_ops.zeros_initializer()) - return inputs + return inputs * 2 layer = MyLayer(name='my_layer') inputs = random_ops.random_uniform((5,), seed=1) layer.apply(inputs) layer.apply(inputs) self.assertListEqual([v.name for v in layer.variables], - ['my_layer/my_var:0', 'my_layer/my_call_var:0']) + ['my_layer/my_var:0']) # Creating a layer with no scope leads to lazy construction of # the scope at apply() time. It uses scope "/base_name" lazy_layer = MyLayer(_reuse=True) with variable_scope.variable_scope('new_scope'): - # This should attempt to reuse 'my_var' and 'my_call_var' in 'new_scope' + # This should attempt to reuse 'my_var' in 'new_scope' with self.assertRaisesRegexp( ValueError, r'new_scope/my_layer/my_var does not exist'): lazy_layer.apply(inputs) with variable_scope.variable_scope('my_layer'): variable_scope.get_variable('my_var', [2, 2]) - with self.assertRaisesRegexp( - ValueError, r'new_scope/my_layer/my_call_var does not exist'): - lazy_layer.apply(inputs) - with variable_scope.variable_scope('my_layer'): - variable_scope.get_variable('my_call_var', [2, 2]) + # Smoke test: it runs. lazy_layer.apply(inputs) # The variables were created outside of the Layer, and # reuse=True, so the Layer does not own them and they are not # stored in its collection. self.assertListEqual(lazy_layer.variables, []) - self.assertEqual(lazy_layer.name, 'new_scope/my_layer') + self.assertEqual(lazy_layer._scope.name, 'new_scope/my_layer') # Creating a layer with no scope leads to lazy construction of # the scope at apply() time. If 'scope' argument is passed to # apply(), it uses that scope when accessing variables. lazy_layer = MyLayer(_reuse=True) with variable_scope.variable_scope('new_scope') as new_scope: - # This should attempt to reuse 'my_var' and 'my_call_var' in 'new_scope' + # This should attempt to reuse 'my_var' in 'new_scope' with self.assertRaisesRegexp( ValueError, r'new_scope/my_var does not exist'): lazy_layer.apply(inputs, scope=new_scope) variable_scope.get_variable('my_var', [2, 2]) - with self.assertRaisesRegexp( - ValueError, r'new_scope/my_call_var does not exist'): - lazy_layer.apply(inputs, scope=new_scope) - variable_scope.get_variable('my_call_var', [2, 2]) + # Smoke test: it runs. lazy_layer.apply(inputs, scope=new_scope) # The variables were created outside of the Layer, and # reuse=True, so the Layer does not own them and they are not # stored in its collection. self.assertListEqual(lazy_layer.variables, []) - self.assertEqual(lazy_layer.name, 'new_scope') + self.assertEqual(lazy_layer._scope.name, 'new_scope') with ops.Graph().as_default(): inputs_ng = random_ops.random_uniform((5,), seed=1) with self.assertRaisesRegexp(ValueError, - r'graphs are not the same'): + r'graph are not the same'): layer.apply(inputs_ng) def testCall(self): - class MyLayer(base_layers._Layer): + class MyLayer(base_layers.Layer): def call(self, inputs): return math_ops.square(inputs) @@ -168,7 +155,7 @@ class BaseLayerTest(test.TestCase): def testDeepCopy(self): - class MyLayer(base_layers._Layer): + class MyLayer(base_layers.Layer): def call(self, inputs): return math_ops.square(inputs) @@ -184,9 +171,9 @@ class BaseLayerTest(test.TestCase): self.assertEqual(layer_copy._scope.name, layer._scope.name) self.assertEqual(layer_copy._graph, layer._graph) - def testNaming(self): + def testScopeNaming(self): - class PrivateLayer(base_layers._Layer): + class PrivateLayer(base_layers.Layer): def call(self, inputs): return None @@ -194,41 +181,42 @@ class BaseLayerTest(test.TestCase): inputs = random_ops.random_uniform((5,)) default_layer = PrivateLayer() _ = default_layer.apply(inputs) - self.assertEqual(default_layer.name, 'private_layer') + self.assertEqual(default_layer._scope.name, 'private_layer') default_layer1 = PrivateLayer() default_layer1.apply(inputs) - self.assertEqual(default_layer1.name, 'private_layer_1') + self.assertEqual(default_layer1._scope.name, 'private_layer_1') my_layer = PrivateLayer(name='my_layer') my_layer.apply(inputs) - self.assertEqual(my_layer.name, 'my_layer') + self.assertEqual(my_layer._scope.name, 'my_layer') my_layer1 = PrivateLayer(name='my_layer') my_layer1.apply(inputs) - self.assertEqual(my_layer1.name, 'my_layer_1') + self.assertEqual(my_layer1._scope.name, 'my_layer_1') my_layer2 = PrivateLayer(name='my_layer') my_layer2.apply(inputs) - self.assertEqual(my_layer2.name, 'my_layer_2') + self.assertEqual(my_layer2._scope.name, 'my_layer_2') # Name scope shouldn't affect names. with ops.name_scope('some_name_scope'): default_layer2 = PrivateLayer() default_layer2.apply(inputs) - self.assertEqual(default_layer2.name, 'private_layer_2') + self.assertEqual(default_layer2._scope.name, 'private_layer_2') my_layer3 = PrivateLayer(name='my_layer') my_layer3.apply(inputs) - self.assertEqual(my_layer3.name, 'my_layer_3') + self.assertEqual(my_layer3._scope.name, 'my_layer_3') other_layer = PrivateLayer(name='other_layer') other_layer.apply(inputs) - self.assertEqual(other_layer.name, 'other_layer') - # Variable scope gets added to names. + self.assertEqual(other_layer._scope.name, 'other_layer') + # Variable scope gets added to scope names. with variable_scope.variable_scope('var_scope'): default_layer_scoped = PrivateLayer() default_layer_scoped.apply(inputs) - self.assertEqual(default_layer_scoped.name, 'var_scope/private_layer') + self.assertEqual(default_layer_scoped._scope.name, + 'var_scope/private_layer') my_layer_scoped = PrivateLayer(name='my_layer') my_layer_scoped.apply(inputs) - self.assertEqual(my_layer_scoped.name, 'var_scope/my_layer') + self.assertEqual(my_layer_scoped._scope.name, 'var_scope/my_layer') my_layer_scoped1 = PrivateLayer(name='my_layer') my_layer_scoped1.apply(inputs) - self.assertEqual(my_layer_scoped1.name, 'var_scope/my_layer_1') + self.assertEqual(my_layer_scoped1._scope.name, 'var_scope/my_layer_1') if __name__ == '__main__': diff --git a/tensorflow/python/layers/convolutional.py b/tensorflow/python/layers/convolutional.py index 3b8959e210..50709bb51d 100644 --- a/tensorflow/python/layers/convolutional.py +++ b/tensorflow/python/layers/convolutional.py @@ -37,7 +37,7 @@ from tensorflow.python.layers import base from tensorflow.python.layers import utils -class _Conv(base._Layer): # pylint: disable=protected-access +class _Conv(base.Layer): """Abstract nD convolution layer (private, used as implementation base). This layer creates a convolution kernel that is convolved @@ -130,19 +130,19 @@ class _Conv(base._Layer): # pylint: disable=protected-access input_dim = input_shape[channel_axis].value kernel_shape = self.kernel_size + (input_dim, self.filters) - self.kernel = vs.get_variable('kernel', - shape=kernel_shape, - initializer=self.kernel_initializer, - regularizer=self.kernel_regularizer, - trainable=True, - dtype=self.dtype) + self.kernel = self.add_variable(name='kernel', + shape=kernel_shape, + initializer=self.kernel_initializer, + regularizer=self.kernel_regularizer, + trainable=True, + dtype=self.dtype) if self.use_bias: - self.bias = vs.get_variable('bias', - shape=(self.filters,), - initializer=self.bias_initializer, - regularizer=self.bias_regularizer, - trainable=True, - dtype=self.dtype) + self.bias = self.add_variable(name='bias', + shape=(self.filters,), + initializer=self.bias_initializer, + regularizer=self.bias_regularizer, + trainable=True, + dtype=self.dtype) else: self.bias = None @@ -814,27 +814,27 @@ class SeparableConv2D(Conv2D): self.depth_multiplier * input_dim, self.filters) - self.depthwise_kernel = vs.get_variable( - 'depthwise_kernel', + self.depthwise_kernel = self.add_variable( + name='depthwise_kernel', shape=depthwise_kernel_shape, initializer=self.depthwise_initializer, regularizer=self.depthwise_regularizer, trainable=True, dtype=self.dtype) - self.pointwise_kernel = vs.get_variable( - 'pointwise_kernel', + self.pointwise_kernel = self.add_variable( + name='pointwise_kernel', shape=pointwise_kernel_shape, initializer=self.pointwise_initializer, regularizer=self.pointwise_regularizer, trainable=True, dtype=self.dtype) if self.use_bias: - self.bias = vs.get_variable('bias', - shape=(self.filters,), - initializer=self.bias_initializer, - regularizer=self.bias_regularizer, - trainable=True, - dtype=self.dtype) + self.bias = self.add_variable(name='bias', + shape=(self.filters,), + initializer=self.bias_initializer, + regularizer=self.bias_regularizer, + trainable=True, + dtype=self.dtype) else: self.bias = None @@ -1055,19 +1055,19 @@ class Conv2DTranspose(Conv2D): input_dim = input_shape[channel_axis] kernel_shape = self.kernel_size + (self.filters, input_dim) - self.kernel = vs.get_variable('kernel', - shape=kernel_shape, - initializer=self.kernel_initializer, - regularizer=self.kernel_regularizer, - trainable=True, - dtype=self.dtype) + self.kernel = self.add_variable(name='kernel', + shape=kernel_shape, + initializer=self.kernel_initializer, + regularizer=self.kernel_regularizer, + trainable=True, + dtype=self.dtype) if self.use_bias: - self.bias = vs.get_variable('bias', - shape=(self.filters,), - initializer=self.bias_initializer, - regularizer=self.bias_regularizer, - trainable=True, - dtype=self.dtype) + self.bias = self.add_variable(name='bias', + shape=(self.filters,), + initializer=self.bias_initializer, + regularizer=self.bias_regularizer, + trainable=True, + dtype=self.dtype) else: self.bias = None diff --git a/tensorflow/python/layers/core.py b/tensorflow/python/layers/core.py index b5846ae3d2..49f6499ca4 100644 --- a/tensorflow/python/layers/core.py +++ b/tensorflow/python/layers/core.py @@ -38,7 +38,7 @@ from tensorflow.python.layers import base from tensorflow.python.layers import utils -class Dense(base._Layer): # pylint: disable=protected-access +class Dense(base.Layer): """Densely-connected layer class. This layer implements the operation: @@ -115,19 +115,19 @@ class Dense(base._Layer): # pylint: disable=protected-access # weight of the layer. If the layer is not trainable # (self.trainable = False), the variable will not be added to # tf.trainable_variables(), and self.trainable_weights will be empty. - self.kernel = vs.get_variable('kernel', - shape=[input_shape[-1].value, self.units], - initializer=self.kernel_initializer, - regularizer=self.kernel_regularizer, - dtype=self.dtype, - trainable=True) + self.kernel = self.add_variable('kernel', + shape=[input_shape[-1].value, self.units], + initializer=self.kernel_initializer, + regularizer=self.kernel_regularizer, + dtype=self.dtype, + trainable=True) if self.use_bias: - self.bias = vs.get_variable('bias', - shape=[self.units,], - initializer=self.bias_initializer, - regularizer=self.bias_regularizer, - dtype=self.dtype, - trainable=True) + self.bias = self.add_variable('bias', + shape=[self.units,], + initializer=self.bias_initializer, + regularizer=self.bias_regularizer, + dtype=self.dtype, + trainable=True) else: self.bias = None @@ -219,7 +219,7 @@ def dense( return layer.apply(inputs) -class Dropout(base._Layer): # pylint: disable=protected-access +class Dropout(base.Layer): """Applies Dropout to the input. Dropout consists in randomly setting a fraction `rate` of input units to 0 diff --git a/tensorflow/python/layers/core_test.py b/tensorflow/python/layers/core_test.py index df650535d4..3939969159 100644 --- a/tensorflow/python/layers/core_test.py +++ b/tensorflow/python/layers/core_test.py @@ -44,16 +44,14 @@ class DenseTest(test.TestCase): self.assertEqual(dense.bias_regularizer, None) self.assertEqual(dense.activity_regularizer, None) self.assertEqual(dense.use_bias, True) - with self.assertRaisesRegexp(ValueError, 'not been used yet'): - _ = dense.name # Test auto-naming dense = core_layers.Dense(2, activation=nn_ops.relu) dense.apply(np.random.randn(0, 2)) - self.assertEqual(dense.name, 'dense') + self.assertEqual(dense.name, 'dense_1') dense = core_layers.Dense(2, activation=nn_ops.relu) dense.apply(np.random.randn(0, 2)) - self.assertEqual(dense.name, 'dense_1') + self.assertEqual(dense.name, 'dense_2') def testCall(self): dense = core_layers.Dense(2, activation=nn_ops.relu, name='my_dense') @@ -62,8 +60,6 @@ class DenseTest(test.TestCase): self.assertListEqual(dense.variables, [dense.kernel, dense.bias]) self.assertListEqual(dense.trainable_variables, [dense.kernel, dense.bias]) self.assertListEqual(dense.non_trainable_variables, []) - self.assertListEqual(dense._trainable_variables, [dense.kernel, dense.bias]) - self.assertListEqual(dense._non_trainable_variables, []) self.assertEqual( len(ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)), 2) self.assertEqual(dense.kernel.name, 'my_dense/kernel:0') @@ -89,8 +85,6 @@ class DenseTest(test.TestCase): self.assertListEqual(dense.non_trainable_variables, [dense.kernel, dense.bias]) self.assertListEqual(dense.trainable_variables, []) - self.assertListEqual(dense._trainable_variables, [dense.kernel, dense.bias]) - self.assertListEqual(dense._non_trainable_variables, []) self.assertEqual( len(ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)), 0) @@ -289,7 +283,7 @@ class DenseTest(test.TestCase): class DropoutTest(test.TestCase): def testDropoutProperties(self): - dp = core_layers.Dropout(0.5) + dp = core_layers.Dropout(0.5, name='dropout') self.assertEqual(dp.rate, 0.5) self.assertEqual(dp.noise_shape, None) dp.apply(np.ones(())) diff --git a/tensorflow/python/layers/normalization.py b/tensorflow/python/layers/normalization.py index 41846ae0cd..2970ddb8ce 100644 --- a/tensorflow/python/layers/normalization.py +++ b/tensorflow/python/layers/normalization.py @@ -41,7 +41,7 @@ from tensorflow.python.layers import base from tensorflow.python.layers import utils -class BatchNormalization(base._Layer): # pylint: disable=protected-access +class BatchNormalization(base.Layer): """Batch Normalization layer from http://arxiv.org/abs/1502.03167. "Batch Normalization: Accelerating Deep Network Training by Reducing @@ -143,33 +143,33 @@ class BatchNormalization(base._Layer): # pylint: disable=protected-access input_shape) if self.center: - self.beta = vs.get_variable('beta', - shape=(param_dim,), - initializer=self.beta_initializer, - regularizer=self.beta_regularizer, - trainable=True) + self.beta = self.add_variable(name='beta', + shape=(param_dim,), + initializer=self.beta_initializer, + regularizer=self.beta_regularizer, + trainable=True) else: self.beta = None if self.scale: - self.gamma = vs.get_variable('gamma', - shape=(param_dim,), - initializer=self.gamma_initializer, - regularizer=self.gamma_regularizer, - trainable=True) + self.gamma = self.add_variable(name='gamma', + shape=(param_dim,), + initializer=self.gamma_initializer, + regularizer=self.gamma_regularizer, + trainable=True) else: self.gamma = None # Disable variable partitioning when creating the moving mean and variance - partitioner = vs.get_variable_scope().partitioner + partitioner = self._scope.partitioner try: - vs.get_variable_scope().set_partitioner(None) - self.moving_mean = vs.get_variable( - 'moving_mean', + self._scope.set_partitioner(None) + self.moving_mean = self.add_variable( + name='moving_mean', shape=(param_dim,), initializer=self.moving_mean_initializer, trainable=False) - self.moving_variance = vs.get_variable( - 'moving_variance', + self.moving_variance = self.add_variable( + name='moving_variance', shape=(param_dim,), initializer=self.moving_variance_initializer, trainable=False) @@ -182,10 +182,10 @@ class BatchNormalization(base._Layer): # pylint: disable=protected-access # stack to be cleared. The nested ones use a `lambda` to set the desired # device and ignore any devices that may be set by the custom getter. def _renorm_variable(name, shape): - var = vs.get_variable(name, - shape=shape, - initializer=init_ops.zeros_initializer(), - trainable=False) + var = self.add_variable(name=name, + shape=shape, + initializer=init_ops.zeros_initializer(), + trainable=False) return var with ops.device(None): with ops.device(lambda _: self.moving_mean.device): @@ -200,7 +200,7 @@ class BatchNormalization(base._Layer): # pylint: disable=protected-access self.renorm_stddev_weight = _renorm_variable( 'renorm_stddev_weight', ()) finally: - vs.get_variable_scope().set_partitioner(partitioner) + self._scope.set_partitioner(partitioner) def _renorm_correction_and_moments(self, mean, variance, training): """Returns the correction and update values for renorm.""" @@ -313,11 +313,8 @@ class BatchNormalization(base._Layer): # pylint: disable=protected-access self.moving_variance, new_variance, decay, zero_debias=False) if not self.updates: - # In the future this should be refactored into a self.add_update - # methods in order to allow for instance-based BN layer sharing - # across unrelated input streams (e.g. like in Keras). - self.updates.append(mean_update) - self.updates.append(variance_update) + self.add_update(mean_update) + self.add_update(variance_update) else: mean, variance = self.moving_mean, self.moving_variance diff --git a/tensorflow/python/layers/pooling.py b/tensorflow/python/layers/pooling.py index 3e40423ad6..b819372923 100644 --- a/tensorflow/python/layers/pooling.py +++ b/tensorflow/python/layers/pooling.py @@ -36,7 +36,7 @@ from tensorflow.python.layers import base from tensorflow.python.layers import utils -class _Pooling1D(base._Layer): # pylint: disable=protected-access +class _Pooling1D(base.Layer): """Pooling layer for arbitrary pooling functions, for 1D inputs. This class only exists for code reuse. It will never be an exposed API. @@ -222,7 +222,7 @@ def max_pooling1d(inputs, pool_size, strides, return layer.apply(inputs) -class _Pooling2D(base._Layer): # pylint: disable=protected-access +class _Pooling2D(base.Layer): """Pooling layer for arbitrary pooling functions, for 2D inputs (e.g. images). This class only exists for code reuse. It will never be an exposed API. @@ -407,7 +407,7 @@ def max_pooling2d(inputs, return layer.apply(inputs) -class _Pooling3D(base._Layer): # pylint: disable=protected-access +class _Pooling3D(base.Layer): """Pooling layer for arbitrary pooling functions, for 3D inputs. This class only exists for code reuse. It will never be an exposed API. diff --git a/tensorflow/python/ops/rnn_cell_impl.py b/tensorflow/python/ops/rnn_cell_impl.py index 32ebe0c2e8..4810e97b36 100644 --- a/tensorflow/python/ops/rnn_cell_impl.py +++ b/tensorflow/python/ops/rnn_cell_impl.py @@ -75,7 +75,7 @@ def _zero_state_tensors(state_size, batch_size, dtype): return zeros -class _RNNCell(base_layer._Layer): # pylint: disable=protected-access +class _RNNCell(base_layer.Layer): # pylint: disable=protected-access """Abstract object representing an RNN cell. Every `RNNCell` must have the properties below and implement `__call__` with -- GitLab From c4ce79d371364a6f10a55e952db1cc718f60f2e5 Mon Sep 17 00:00:00 2001 From: Jonathan Hseu Date: Wed, 26 Apr 2017 16:23:55 -0800 Subject: [PATCH 029/697] Automated rollback of change 154338968 Change: 154367968 --- tensorflow/python/lib/io/file_io.i | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/tensorflow/python/lib/io/file_io.i b/tensorflow/python/lib/io/file_io.i index a6fe802597..c0c4e035fc 100644 --- a/tensorflow/python/lib/io/file_io.i +++ b/tensorflow/python/lib/io/file_io.i @@ -31,13 +31,6 @@ limitations under the License. #include "tensorflow/core/protobuf/meta_graph.pb.h" %} -// Release the Python GIL for the duration of all methods. -%exception { - Py_BEGIN_ALLOW_THREADS; - $action - Py_END_ALLOW_THREADS; -} - %{ inline void FileExists(const string& filename, TF_Status* out_status) { tensorflow::Status status = tensorflow::Env::Default()->FileExists(filename); @@ -306,6 +299,3 @@ string ReadFromStream(tensorflow::io::BufferedInputStream* stream, %include "tensorflow/core/lib/io/path.h" %include "tensorflow/core/platform/file_statistics.h" - -// Delete the previously defined default handler that releases the Python GIL. -%noexception; -- GitLab From 54b5650eb4445548eb1d0e98346e4563b3db0f10 Mon Sep 17 00:00:00 2001 From: Skye Wanderman-Milne Date: Wed, 26 Apr 2017 16:37:52 -0800 Subject: [PATCH 030/697] Name enums in C API (in addition to typedefs). SWIG will sometimes generate type references like "enum TF_DataType", which doesn't compile if only the typedef is named. Change: 154369467 --- tensorflow/c/c_api.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h index e2aeef0d88..ec9b01b388 100644 --- a/tensorflow/c/c_api.h +++ b/tensorflow/c/c_api.h @@ -95,7 +95,7 @@ TF_CAPI_EXPORT extern const char* TF_Version(); // -------------------------------------------------------------------------- // TF_DataType holds the type for a scalar value. E.g., one slot in a tensor. // The enum values here are identical to corresponding values in types.proto. -typedef enum { +typedef enum TF_DataType { TF_FLOAT = 1, TF_DOUBLE = 2, TF_INT32 = 3, // Int32 tensors are always in 'host' memory. @@ -127,7 +127,7 @@ TF_CAPI_EXPORT extern size_t TF_DataTypeSize(TF_DataType dt); // -------------------------------------------------------------------------- // TF_Code holds an error code. The enum values here are identical to // corresponding values in error_codes.proto. -typedef enum { +typedef enum TF_Code { TF_OK = 0, TF_CANCELLED = 1, TF_UNKNOWN = 2, @@ -629,7 +629,7 @@ TF_CAPI_EXPORT extern int TF_OperationGetControlOutputs( int max_control_outputs); // TF_AttrType describes the type of the value of an attribute on an operation. -typedef enum { +typedef enum TF_AttrType { TF_ATTR_STRING = 0, TF_ATTR_INT = 1, TF_ATTR_FLOAT = 2, -- GitLab From ed37842fe4b1a83922c0042f59775968f350dc8a Mon Sep 17 00:00:00 2001 From: Suharsh Sivakumar Date: Wed, 26 Apr 2017 17:00:54 -0800 Subject: [PATCH 031/697] Lower number of parallel jobs to prevent OOM. Change: 154371550 --- tensorflow/contrib/makefile/build_helper.subr | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/contrib/makefile/build_helper.subr b/tensorflow/contrib/makefile/build_helper.subr index f0452944e2..d58b2c0a9b 100644 --- a/tensorflow/contrib/makefile/build_helper.subr +++ b/tensorflow/contrib/makefile/build_helper.subr @@ -31,7 +31,7 @@ get_cpu_count() { } get_job_count() { - echo $(($(get_cpu_count) * 2)) + echo $(($(get_cpu_count))) } make_host_protoc() { -- GitLab From f37eb35d82d734504da273fed7a5a0fae358c9bd Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 26 Apr 2017 17:11:54 -0800 Subject: [PATCH 032/697] Go back to constructing dashboards via the DOM. Change: 154372639 --- .../tf_tensorboard/tf-tensorboard.html | 170 +++++++++++------- 1 file changed, 103 insertions(+), 67 deletions(-) diff --git a/tensorflow/tensorboard/components/tf_tensorboard/tf-tensorboard.html b/tensorflow/tensorboard/components/tf_tensorboard/tf-tensorboard.html index a1925c1f6b..b5b2e2d5a8 100644 --- a/tensorflow/tensorboard/components/tf_tensorboard/tf-tensorboard.html +++ b/tensorflow/tensorboard/components/tf_tensorboard/tf-tensorboard.html @@ -57,9 +57,9 @@ allows the user to toggle between various dashboards.
TensorBoard
- + diff --git a/tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tf-regex-group-demo.html b/tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tf-regex-group-demo.html index 051a58e270..3565fec179 100644 --- a/tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tf-regex-group-demo.html +++ b/tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tf-regex-group-demo.html @@ -19,7 +19,6 @@ limitations under the License. - +

Answer to the Ultimate Question of Life, the Universe, and Everything

+ + + diff --git a/tensorflow/tensorboard/components/tf_graph_app_d3v4/index.html b/tensorflow/tensorboard/components/tf_graph_app_d3v4/index.html new file mode 100644 index 0000000000..c80fbf4f63 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_graph_app_d3v4/index.html @@ -0,0 +1,30 @@ + + + + + + vz-vega + + + + + + + + + diff --git a/tensorflow/tensorboard/components/tf_graph_app_d3v4/tf-graph-app.html b/tensorflow/tensorboard/components/tf_graph_app_d3v4/tf-graph-app.html new file mode 100644 index 0000000000..915b54a06a --- /dev/null +++ b/tensorflow/tensorboard/components/tf_graph_app_d3v4/tf-graph-app.html @@ -0,0 +1,152 @@ + + + + + + + + + + + + + diff --git a/tensorflow/tensorboard/components/tf_graph_board_d3v4/BUILD b/tensorflow/tensorboard/components/tf_graph_board_d3v4/BUILD new file mode 100644 index 0000000000..7203a9333b --- /dev/null +++ b/tensorflow/tensorboard/components/tf_graph_board_d3v4/BUILD @@ -0,0 +1,28 @@ +package(default_visibility = ["//tensorflow:internal"]) + +load("@io_bazel_rules_closure//closure:defs.bzl", "web_library") +load("//tensorflow/tensorboard:defs.bzl", "tensorboard_ts_library") +load("//tensorflow/tensorboard:defs.bzl", "tensorboard_webcomponent_library") + +licenses(["notice"]) # Apache 2.0 + +web_library( + name = "tf_graph_board_d3v4", + srcs = [ + "tf-graph-board.html", + ], + path = "/tf-graph-board", + deps = [ + "//tensorflow/tensorboard/components/tf_graph_common_d3v4", + "//tensorflow/tensorboard/components/tf_graph_d3v4", + "//tensorflow/tensorboard/components/tf_graph_info_d3v4", + "@org_polymer", + "@org_polymer_paper_progress", + ], +) + +filegroup( + name = "all_files", + srcs = glob(["**"]), + tags = ["notsan"], +) diff --git a/tensorflow/tensorboard/components/tf_graph_board_d3v4/demo/BUILD b/tensorflow/tensorboard/components/tf_graph_board_d3v4/demo/BUILD new file mode 100644 index 0000000000..2d668769e6 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_graph_board_d3v4/demo/BUILD @@ -0,0 +1,26 @@ +package(default_visibility = ["//tensorflow:internal"]) + +load("@io_bazel_rules_closure//closure:defs.bzl", "web_library") + +licenses(["notice"]) # Apache 2.0 + +# bazel run //third_party/tensorflow/tensorboard/components/tf_graph_board/demo +web_library( + name = "demo", + srcs = ["index.html"] + glob(["data/**"]), + path = "/tf-graph-board/demo", + deps = [ + "//tensorflow/tensorboard/components/tf_graph_board", + "//tensorflow/tensorboard/components/tf_graph_common", + "//tensorflow/tensorboard/components/tf_graph_loader", + "@org_polymer_iron_demo_helpers", + "@org_polymer_paper_styles", + "@org_polymer_webcomponentsjs", + ], +) + +filegroup( + name = "all_files", + srcs = glob(["**"]), + tags = ["notsan"], +) diff --git a/tensorflow/tensorboard/components/tf_graph_board_d3v4/demo/data/graph.pbtxt b/tensorflow/tensorboard/components/tf_graph_board_d3v4/demo/data/graph.pbtxt new file mode 100644 index 0000000000..30b2064534 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_graph_board_d3v4/demo/data/graph.pbtxt @@ -0,0 +1,4606 @@ +node { + name: "GradientDescent/learning_rate" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_3" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.1 + } + } + } +} +node { + name: "gradients/add_grad/Shape_1" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 100 + } + } + } +} +node { + name: "gradients/add_grad/Shape" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\310\000\000\000d\000\000\000" + } + } + } +} +node { + name: "gradients/add_grad/BroadcastGradientArgs" + op: "BroadcastGradientArgs" + input: "gradients/add_grad/Shape" + input: "gradients/add_grad/Shape_1" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + shape { + dim { + size: -1 + } + } + } + } + } +} +node { + name: "gradients/add_1_grad/Shape_1" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 10 + } + } + } +} +node { + name: "gradients/add_1_grad/Shape" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\310\000\000\000\n\000\000\000" + } + } + } +} +node { + name: "gradients/add_1_grad/BroadcastGradientArgs" + op: "BroadcastGradientArgs" + input: "gradients/add_1_grad/Shape" + input: "gradients/add_1_grad/Shape_1" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + shape { + dim { + size: -1 + } + } + } + } + } +} +node { + name: "gradients/Reshape_1_grad/Shape" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\310\000\000\000\n\000\000\000" + } + } + } +} +node { + name: "gradients/SoftmaxCrossEntropyWithLogits_grad/ExpandDims/dim" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: -1 + } + } + } +} +node { + name: "gradients/Reshape_3_grad/Shape" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 200 + } + } + } +} +node { + name: "gradients/Mean_grad/Maximum/y" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } +} +node { + name: "gradients/Mean_grad/Const_1" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } +} +node { + name: "gradients/Mean_grad/Const" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } +} +node { + name: "gradients/Mean_grad/Shape_1" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + } + } + } + } + } +} +node { + name: "gradients/Mean_grad/Prod_1" + op: "Prod" + input: "gradients/Mean_grad/Shape_1" + input: "gradients/Mean_grad/Const_1" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } +} +node { + name: "gradients/Mean_grad/Maximum" + op: "Maximum" + input: "gradients/Mean_grad/Prod_1" + input: "gradients/Mean_grad/Maximum/y" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } +} +node { + name: "gradients/Mean_grad/Shape" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 200 + } + } + } +} +node { + name: "gradients/Mean_grad/Prod" + op: "Prod" + input: "gradients/Mean_grad/Shape" + input: "gradients/Mean_grad/Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } +} +node { + name: "gradients/Mean_grad/floordiv" + op: "FloorDiv" + input: "gradients/Mean_grad/Prod" + input: "gradients/Mean_grad/Maximum" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } +} +node { + name: "gradients/Mean_grad/Cast" + op: "Cast" + input: "gradients/Mean_grad/floordiv" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "DstT" + value { + type: DT_FLOAT + } + } + attr { + key: "SrcT" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } +} +node { + name: "gradients/Mean_grad/Tile/multiples" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 200 + } + } + } +} +node { + name: "gradients/Mean_grad/Reshape/shape" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } +} +node { + name: "gradients/Const" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1 + } + } + } +} +node { + name: "gradients/Shape" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + } + } + } + } + } +} +node { + name: "gradients/Fill" + op: "Fill" + input: "gradients/Shape" + input: "gradients/Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } +} +node { + name: "gradients/Mean_grad/Reshape" + op: "Reshape" + input: "gradients/Fill" + input: "gradients/Mean_grad/Reshape/shape" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } +} +node { + name: "gradients/Mean_grad/Tile" + op: "Tile" + input: "gradients/Mean_grad/Reshape" + input: "gradients/Mean_grad/Tile/multiples" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tmultiples" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + } + } + } + } +} +node { + name: "gradients/Mean_grad/truediv" + op: "RealDiv" + input: "gradients/Mean_grad/Tile" + input: "gradients/Mean_grad/Cast" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + } + } + } + } +} +node { + name: "gradients/Reshape_3_grad/Reshape" + op: "Reshape" + input: "gradients/Mean_grad/truediv" + input: "gradients/Reshape_3_grad/Shape" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + } + } + } + } +} +node { + name: "gradients/SoftmaxCrossEntropyWithLogits_grad/ExpandDims" + op: "ExpandDims" + input: "gradients/Reshape_3_grad/Reshape" + input: "gradients/SoftmaxCrossEntropyWithLogits_grad/ExpandDims/dim" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tdim" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 1 + } + } + } + } + } +} +node { + name: "Const" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } +} +node { + name: "Slice_2/begin" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } +} +node { + name: "Sub_2/y" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } +} +node { + name: "concat_1/axis" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 0 + } + } + } +} +node { + name: "concat_1/values_0" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: -1 + } + } + } +} +node { + name: "Slice_1/size" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } +} +node { + name: "Sub_1/y" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } +} +node { + name: "Shape_2" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\310\000\000\000\n\000\000\000" + } + } + } +} +node { + name: "Rank_2" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 2 + } + } + } +} +node { + name: "Sub_1" + op: "Sub" + input: "Rank_2" + input: "Sub_1/y" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } +} +node { + name: "Slice_1/begin" + op: "Pack" + input: "Sub_1" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "N" + value { + i: 1 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "axis" + value { + i: 0 + } + } +} +node { + name: "Slice_1" + op: "Slice" + input: "Shape_2" + input: "Slice_1/begin" + input: "Slice_1/size" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } +} +node { + name: "concat_1" + op: "ConcatV2" + input: "concat_1/values_0" + input: "Slice_1" + input: "concat_1/axis" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } +} +node { + name: "concat/axis" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 0 + } + } + } +} +node { + name: "concat/values_0" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: -1 + } + } + } +} +node { + name: "Slice/size" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } +} +node { + name: "Sub/y" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } +} +node { + name: "Shape_1" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\310\000\000\000\n\000\000\000" + } + } + } +} +node { + name: "Rank_1" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 2 + } + } + } +} +node { + name: "Sub" + op: "Sub" + input: "Rank_1" + input: "Sub/y" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } +} +node { + name: "Slice/begin" + op: "Pack" + input: "Sub" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "N" + value { + i: 1 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "axis" + value { + i: 0 + } + } +} +node { + name: "Slice" + op: "Slice" + input: "Shape_1" + input: "Slice/begin" + input: "Slice/size" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } +} +node { + name: "concat" + op: "ConcatV2" + input: "concat/values_0" + input: "Slice" + input: "concat/axis" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } +} +node { + name: "Shape" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\310\000\000\000\n\000\000\000" + } + } + } +} +node { + name: "Rank" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 2 + } + } + } +} +node { + name: "Sub_2" + op: "Sub" + input: "Rank" + input: "Sub_2/y" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } +} +node { + name: "Slice_2/size" + op: "Pack" + input: "Sub_2" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "N" + value { + i: 1 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "axis" + value { + i: 0 + } + } +} +node { + name: "Slice_2" + op: "Slice" + input: "Shape" + input: "Slice_2/begin" + input: "Slice_2/size" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } +} +node { + name: "logits_biases" + op: "VariableV2" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "_class" + value { + list { + s: "loc:@logits_biases" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 10 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "logits_biases/read" + op: "Identity" + input: "logits_biases" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@logits_biases" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } +} +node { + name: "logits_weights" + op: "VariableV2" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "_class" + value { + list { + s: "loc:@logits_weights" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 100 + } + dim { + size: 10 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "logits_weights/read" + op: "Identity" + input: "logits_weights" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@logits_weights" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + dim { + size: 10 + } + } + } + } + } +} +node { + name: "hidden_biases" + op: "VariableV2" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "_class" + value { + list { + s: "loc:@hidden_biases" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + } + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 100 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "hidden_biases/read" + op: "Identity" + input: "hidden_biases" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@hidden_biases" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + } + } + } + } +} +node { + name: "hidden_weights" + op: "VariableV2" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "_class" + value { + list { + s: "loc:@hidden_weights" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 784 + } + dim { + size: 100 + } + } + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 784 + } + dim { + size: 100 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "hidden_weights/read" + op: "Identity" + input: "hidden_weights" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@hidden_weights" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 784 + } + dim { + size: 100 + } + } + } + } + } +} +node { + name: "Reshape/shape" + op: "Const" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\310\000\000\000\377\377\377\377" + } + } + } +} +node { + name: "mnist_dataset_train_2/one_hot/depth" + op: "Const" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 10 + } + } + } +} +node { + name: "mnist_dataset_train_2/one_hot/off_value" + op: "Const" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0 + } + } + } +} +node { + name: "mnist_dataset_train_2/one_hot/on_value" + op: "Const" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1 + } + } + } +} +node { + name: "mnist_dataset_train_2/random_shuffle_queue_DequeueMany/n" + op: "Const" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 200 + } + } + } +} +node { + name: "mnist_dataset_train_1/random_shuffle_queue" + op: "RandomShuffleQueueV2" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "capacity" + value { + i: 20000 + } + } + attr { + key: "component_types" + value { + list { + type: DT_FLOAT + type: DT_INT64 + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "min_after_dequeue" + value { + i: 4000 + } + } + attr { + key: "seed" + value { + i: 0 + } + } + attr { + key: "seed2" + value { + i: 0 + } + } + attr { + key: "shapes" + value { + list { + shape { + dim { + size: 28 + } + dim { + size: 28 + } + dim { + size: 1 + } + } + shape { + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "mnist_dataset_train_2/random_shuffle_queue_DequeueMany" + op: "QueueDequeueManyV2" + input: "mnist_dataset_train_1/random_shuffle_queue" + input: "mnist_dataset_train_2/random_shuffle_queue_DequeueMany/n" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "_output_shapes" + value { + list { + shape { + unknown_rank: true + } + shape { + unknown_rank: true + } + } + } + } + attr { + key: "component_types" + value { + list { + type: DT_FLOAT + type: DT_INT64 + } + } + } + attr { + key: "timeout_ms" + value { + i: -1 + } + } +} +node { + name: "Reshape" + op: "Reshape" + input: "mnist_dataset_train_2/random_shuffle_queue_DequeueMany" + input: "Reshape/shape" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: -1 + } + } + } + } + } +} +node { + name: "MatMul" + op: "MatMul" + input: "Reshape" + input: "hidden_weights/read" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 100 + } + } + } + } + } + attr { + key: "transpose_a" + value { + b: false + } + } + attr { + key: "transpose_b" + value { + b: false + } + } +} +node { + name: "add" + op: "Add" + input: "MatMul" + input: "hidden_biases/read" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 100 + } + } + } + } + } +} +node { + name: "Relu" + op: "Relu" + input: "add" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 100 + } + } + } + } + } +} +node { + name: "MatMul_1" + op: "MatMul" + input: "Relu" + input: "logits_weights/read" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "transpose_a" + value { + b: false + } + } + attr { + key: "transpose_b" + value { + b: false + } + } +} +node { + name: "add_1" + op: "Add" + input: "MatMul_1" + input: "logits_biases/read" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 10 + } + } + } + } + } +} +node { + name: "Reshape_1" + op: "Reshape" + input: "add_1" + input: "concat" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 10 + } + } + } + } + } +} +node { + name: "mnist_dataset_train_2/one_hot" + op: "OneHot" + input: "mnist_dataset_train_2/random_shuffle_queue_DequeueMany:1" + input: "mnist_dataset_train_2/one_hot/depth" + input: "mnist_dataset_train_2/one_hot/on_value" + input: "mnist_dataset_train_2/one_hot/off_value" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "TI" + value { + type: DT_INT64 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + unknown_rank: true + } + } + } + } + attr { + key: "axis" + value { + i: -1 + } + } +} +node { + name: "Reshape_2" + op: "Reshape" + input: "mnist_dataset_train_2/one_hot" + input: "concat_1" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 10 + } + } + } + } + } +} +node { + name: "SoftmaxCrossEntropyWithLogits" + op: "SoftmaxCrossEntropyWithLogits" + input: "Reshape_1" + input: "Reshape_2" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + } + shape { + dim { + size: 200 + } + dim { + size: 10 + } + } + } + } + } +} +node { + name: "gradients/SoftmaxCrossEntropyWithLogits_grad/PreventGradient" + op: "PreventGradient" + input: "SoftmaxCrossEntropyWithLogits:1" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "message" + value { + s: "Currently there is no way to take the second derivative of softmax_cross_entropy_with_logits due to the fused implementation\'s interaction with tf.gradients()" + } + } +} +node { + name: "gradients/SoftmaxCrossEntropyWithLogits_grad/mul" + op: "Mul" + input: "gradients/SoftmaxCrossEntropyWithLogits_grad/ExpandDims" + input: "gradients/SoftmaxCrossEntropyWithLogits_grad/PreventGradient" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 10 + } + } + } + } + } +} +node { + name: "gradients/Reshape_1_grad/Reshape" + op: "Reshape" + input: "gradients/SoftmaxCrossEntropyWithLogits_grad/mul" + input: "gradients/Reshape_1_grad/Shape" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 10 + } + } + } + } + } +} +node { + name: "gradients/add_1_grad/Sum_1" + op: "Sum" + input: "gradients/Reshape_1_grad/Reshape" + input: "gradients/add_1_grad/BroadcastGradientArgs:1" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } +} +node { + name: "gradients/add_1_grad/Reshape_1" + op: "Reshape" + input: "gradients/add_1_grad/Sum_1" + input: "gradients/add_1_grad/Shape_1" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } +} +node { + name: "gradients/add_1_grad/Sum" + op: "Sum" + input: "gradients/Reshape_1_grad/Reshape" + input: "gradients/add_1_grad/BroadcastGradientArgs" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } +} +node { + name: "gradients/add_1_grad/Reshape" + op: "Reshape" + input: "gradients/add_1_grad/Sum" + input: "gradients/add_1_grad/Shape" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 10 + } + } + } + } + } +} +node { + name: "gradients/add_1_grad/tuple/group_deps" + op: "NoOp" + input: "^gradients/add_1_grad/Reshape" + input: "^gradients/add_1_grad/Reshape_1" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + } + } + } +} +node { + name: "gradients/add_1_grad/tuple/control_dependency_1" + op: "Identity" + input: "gradients/add_1_grad/Reshape_1" + input: "^gradients/add_1_grad/tuple/group_deps" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_class" + value { + list { + s: "loc:@gradients/add_1_grad/Reshape_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } +} +node { + name: "GradientDescent/update_logits_biases/ApplyGradientDescent" + op: "ApplyGradientDescent" + input: "logits_biases" + input: "GradientDescent/learning_rate" + input: "gradients/add_1_grad/tuple/control_dependency_1" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@logits_biases" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: false + } + } +} +node { + name: "gradients/add_1_grad/tuple/control_dependency" + op: "Identity" + input: "gradients/add_1_grad/Reshape" + input: "^gradients/add_1_grad/tuple/group_deps" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_class" + value { + list { + s: "loc:@gradients/add_1_grad/Reshape" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 10 + } + } + } + } + } +} +node { + name: "gradients/MatMul_1_grad/MatMul_1" + op: "MatMul" + input: "Relu" + input: "gradients/add_1_grad/tuple/control_dependency" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "transpose_a" + value { + b: true + } + } + attr { + key: "transpose_b" + value { + b: false + } + } +} +node { + name: "gradients/MatMul_1_grad/MatMul" + op: "MatMul" + input: "gradients/add_1_grad/tuple/control_dependency" + input: "logits_weights/read" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 100 + } + } + } + } + } + attr { + key: "transpose_a" + value { + b: false + } + } + attr { + key: "transpose_b" + value { + b: true + } + } +} +node { + name: "gradients/MatMul_1_grad/tuple/group_deps" + op: "NoOp" + input: "^gradients/MatMul_1_grad/MatMul" + input: "^gradients/MatMul_1_grad/MatMul_1" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + } + } + } +} +node { + name: "gradients/MatMul_1_grad/tuple/control_dependency_1" + op: "Identity" + input: "gradients/MatMul_1_grad/MatMul_1" + input: "^gradients/MatMul_1_grad/tuple/group_deps" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_class" + value { + list { + s: "loc:@gradients/MatMul_1_grad/MatMul_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + dim { + size: 10 + } + } + } + } + } +} +node { + name: "GradientDescent/update_logits_weights/ApplyGradientDescent" + op: "ApplyGradientDescent" + input: "logits_weights" + input: "GradientDescent/learning_rate" + input: "gradients/MatMul_1_grad/tuple/control_dependency_1" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@logits_weights" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: false + } + } +} +node { + name: "gradients/MatMul_1_grad/tuple/control_dependency" + op: "Identity" + input: "gradients/MatMul_1_grad/MatMul" + input: "^gradients/MatMul_1_grad/tuple/group_deps" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_class" + value { + list { + s: "loc:@gradients/MatMul_1_grad/MatMul" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 100 + } + } + } + } + } +} +node { + name: "gradients/Relu_grad/ReluGrad" + op: "ReluGrad" + input: "gradients/MatMul_1_grad/tuple/control_dependency" + input: "Relu" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 100 + } + } + } + } + } +} +node { + name: "gradients/add_grad/Sum_1" + op: "Sum" + input: "gradients/Relu_grad/ReluGrad" + input: "gradients/add_grad/BroadcastGradientArgs:1" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } +} +node { + name: "gradients/add_grad/Reshape_1" + op: "Reshape" + input: "gradients/add_grad/Sum_1" + input: "gradients/add_grad/Shape_1" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + } + } + } + } +} +node { + name: "gradients/add_grad/Sum" + op: "Sum" + input: "gradients/Relu_grad/ReluGrad" + input: "gradients/add_grad/BroadcastGradientArgs" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 100 + } + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } +} +node { + name: "gradients/add_grad/Reshape" + op: "Reshape" + input: "gradients/add_grad/Sum" + input: "gradients/add_grad/Shape" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 100 + } + } + } + } + } +} +node { + name: "gradients/add_grad/tuple/group_deps" + op: "NoOp" + input: "^gradients/add_grad/Reshape" + input: "^gradients/add_grad/Reshape_1" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + } + } + } +} +node { + name: "gradients/add_grad/tuple/control_dependency_1" + op: "Identity" + input: "gradients/add_grad/Reshape_1" + input: "^gradients/add_grad/tuple/group_deps" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_class" + value { + list { + s: "loc:@gradients/add_grad/Reshape_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + } + } + } + } +} +node { + name: "GradientDescent/update_hidden_biases/ApplyGradientDescent" + op: "ApplyGradientDescent" + input: "hidden_biases" + input: "GradientDescent/learning_rate" + input: "gradients/add_grad/tuple/control_dependency_1" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@hidden_biases" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: false + } + } +} +node { + name: "gradients/add_grad/tuple/control_dependency" + op: "Identity" + input: "gradients/add_grad/Reshape" + input: "^gradients/add_grad/tuple/group_deps" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_class" + value { + list { + s: "loc:@gradients/add_grad/Reshape" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 100 + } + } + } + } + } +} +node { + name: "gradients/MatMul_grad/MatMul_1" + op: "MatMul" + input: "Reshape" + input: "gradients/add_grad/tuple/control_dependency" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 100 + } + } + } + } + } + attr { + key: "transpose_a" + value { + b: true + } + } + attr { + key: "transpose_b" + value { + b: false + } + } +} +node { + name: "gradients/MatMul_grad/MatMul" + op: "MatMul" + input: "gradients/add_grad/tuple/control_dependency" + input: "hidden_weights/read" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 784 + } + } + } + } + } + attr { + key: "transpose_a" + value { + b: false + } + } + attr { + key: "transpose_b" + value { + b: true + } + } +} +node { + name: "gradients/MatMul_grad/tuple/group_deps" + op: "NoOp" + input: "^gradients/MatMul_grad/MatMul" + input: "^gradients/MatMul_grad/MatMul_1" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + } + } + } +} +node { + name: "gradients/MatMul_grad/tuple/control_dependency_1" + op: "Identity" + input: "gradients/MatMul_grad/MatMul_1" + input: "^gradients/MatMul_grad/tuple/group_deps" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_class" + value { + list { + s: "loc:@gradients/MatMul_grad/MatMul_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 100 + } + } + } + } + } +} +node { + name: "GradientDescent/update_hidden_weights/ApplyGradientDescent" + op: "ApplyGradientDescent" + input: "hidden_weights" + input: "GradientDescent/learning_rate" + input: "gradients/MatMul_grad/tuple/control_dependency_1" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@hidden_weights" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 784 + } + dim { + size: 100 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: false + } + } +} +node { + name: "GradientDescent" + op: "NoOp" + input: "^GradientDescent/update_hidden_weights/ApplyGradientDescent" + input: "^GradientDescent/update_hidden_biases/ApplyGradientDescent" + input: "^GradientDescent/update_logits_weights/ApplyGradientDescent" + input: "^GradientDescent/update_logits_biases/ApplyGradientDescent" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_2" + } + } + attr { + key: "_output_shapes" + value { + list { + } + } + } +} +node { + name: "Reshape_3" + op: "Reshape" + input: "SoftmaxCrossEntropyWithLogits" + input: "Slice_2" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + } + } + } + } +} +node { + name: "Mean" + op: "Mean" + input: "Reshape_3" + input: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } +} +node { + name: "_send_Mean_0" + op: "_Send" + input: "Mean" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "client_terminated" + value { + b: true + } + } + attr { + key: "recv_device" + value { + s: "/job:localhost/replica:0/task:0/cpu:0" + } + } + attr { + key: "send_device" + value { + s: "/job:localhost/replica:0/task:0/cpu:0" + } + } + attr { + key: "send_device_incarnation" + value { + i: -5924635994370253548 + } + } + attr { + key: "tensor_name" + value { + s: "Mean:0" + } + } +} +library { +} +versions { + producer: 21 +} diff --git a/tensorflow/tensorboard/components/tf_graph_board_d3v4/demo/index.html b/tensorflow/tensorboard/components/tf_graph_board_d3v4/demo/index.html new file mode 100644 index 0000000000..2563e1595e --- /dev/null +++ b/tensorflow/tensorboard/components/tf_graph_board_d3v4/demo/index.html @@ -0,0 +1,98 @@ + + + + + + + +TF Graph Board Demo + + + + diff --git a/tensorflow/tensorboard/components/tf_graph_board_d3v4/tf-graph-board.html b/tensorflow/tensorboard/components/tf_graph_board_d3v4/tf-graph-board.html new file mode 100644 index 0000000000..0ee694e1e6 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_graph_board_d3v4/tf-graph-board.html @@ -0,0 +1,255 @@ + + + + + + + + + + + + + + diff --git a/tensorflow/tensorboard/components/tf_graph_common_d3v4/BUILD b/tensorflow/tensorboard/components/tf_graph_common_d3v4/BUILD new file mode 100644 index 0000000000..9b7dcef8bc --- /dev/null +++ b/tensorflow/tensorboard/components/tf_graph_common_d3v4/BUILD @@ -0,0 +1,41 @@ +package(default_visibility = ["//tensorflow:internal"]) + +load("@io_bazel_rules_closure//closure:defs.bzl", "web_library") +load("//tensorflow/tensorboard:defs.bzl", "tensorboard_ts_library") +load("//tensorflow/tensorboard:defs.bzl", "tensorboard_typescript_genrule") +load("//tensorflow/tensorboard:defs.bzl", "tensorboard_webcomponent_library") + +licenses(["notice"]) # Apache 2.0 + +web_library( + name = "tf_graph_common_d3v4", + srcs = [ + "tf-graph-common.html", + ":ts", + ], + path = "/tf-graph-common", + deps = [ + "//tensorflow/tensorboard/components/tf_imports_d3v4:d3", + "//tensorflow/tensorboard/components/tf_imports_d3v4:dagre", + "//tensorflow/tensorboard/components/tf_imports_d3v4:graphlib", + "//tensorflow/tensorboard/components/tf_imports_d3v4:lodash", + "@org_polymer", + ], +) + +tensorboard_typescript_genrule( + name = "ts", + srcs = glob(["*.ts"]), + typings = [ + "//tensorflow/tensorboard/components/tf_imports_d3v4:d3.d.ts", + "@org_definitelytyped//:lodash.d.ts", + "@org_definitelytyped//:polymer.d.ts", + "@org_definitelytyped//:webcomponents.js.d.ts", + ], +) + +filegroup( + name = "all_files", + srcs = glob(["**"]), + tags = ["notsan"], +) diff --git a/tensorflow/tensorboard/components/tf_graph_common_d3v4/annotation.ts b/tensorflow/tensorboard/components/tf_graph_common_d3v4/annotation.ts new file mode 100644 index 0000000000..6db0cd5519 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_graph_common_d3v4/annotation.ts @@ -0,0 +1,239 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the 'License'); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an 'AS IS' BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +module tf.graph.scene.annotation { + /** + * Populate a given annotation container group + * + * + * + * with annotation group of the following structure: + * + * + * + * + * + * + * + * @param container selection of the container. + * @param annotationData node.{in|out}Annotations + * @param d node to build group for. + * @param sceneElement polymer element. + * @return selection of appended objects + */ + export function buildGroup( + container, annotationData: render.AnnotationList, + d: render.RenderNodeInfo, sceneElement) { + // Select all children and join with data. + let annotationGroups = + container + .selectAll(function() { + // using d3's selector function + // See https://github.com/mbostock/d3/releases/tag/v2.0.0 + // (It's not listed in the d3 wiki.) + return this.childNodes; + }) + .data(annotationData.list, d => { return d.node.name; }); + + annotationGroups.enter() + .append('g') + .attr('data-name', a => { return a.node.name; }) + .each(function(a) { + let aGroup = d3.select(this); + + // Add annotation to the index in the scene + sceneElement.addAnnotationGroup(a, d, aGroup); + // Append annotation edge + let edgeType = Class.Annotation.EDGE; + let metaedge = a.renderMetaedgeInfo && a.renderMetaedgeInfo.metaedge; + if (metaedge && !metaedge.numRegularEdges) { + edgeType += ' ' + Class.Annotation.CONTROL_EDGE; + } + // If any edges are reference edges, add the reference edge class. + if (metaedge && metaedge.numRefEdges) { + edgeType += ' ' + Class.Edge.REF_LINE; + } + edge.appendEdge(aGroup, a, sceneElement, edgeType); + + if (a.annotationType !== render.AnnotationType.ELLIPSIS) { + addAnnotationLabelFromNode(aGroup, a); + buildShape(aGroup, a); + } else { + addAnnotationLabel( + aGroup, a.node.name, a, Class.Annotation.ELLIPSIS); + } + }); + + annotationGroups + .attr( + 'class', + a => { + return Class.Annotation.GROUP + ' ' + + annotationToClassName(a.annotationType) + ' ' + + node.nodeClass(a); + }) + .each(function(a) { + let aGroup = d3.select(this); + update(aGroup, d, a, sceneElement); + if (a.annotationType !== render.AnnotationType.ELLIPSIS) { + addInteraction(aGroup, d, a, sceneElement); + } + }); + + annotationGroups.exit() + .each(function(a) { + let aGroup = d3.select(this); + + // Remove annotation from the index in the scene + sceneElement.removeAnnotationGroup(a, d, aGroup); + }) + .remove(); + return annotationGroups; +}; + +/** + * Maps an annotation enum to a class name used in css rules. + */ +function annotationToClassName(annotationType: render.AnnotationType) { + return (render.AnnotationType[annotationType] || '').toLowerCase() || null; +} + +function buildShape(aGroup, a: render.Annotation) { + if (a.annotationType === render.AnnotationType.SUMMARY) { + let summary = selectOrCreateChild(aGroup, 'use'); + summary + .attr('class', 'summary') + .attr('xlink:href', '#summary-icon') + .attr('cursor', 'pointer'); + } else { + let shape = node.buildShape(aGroup, a, Class.Annotation.NODE); + // add title tag to get native tooltips + selectOrCreateChild(shape, 'title').text(a.node.name); + } +} + +function addAnnotationLabelFromNode(aGroup, a: render.Annotation) { + let namePath = a.node.name.split('/'); + let text = namePath[namePath.length - 1]; + return addAnnotationLabel(aGroup, text, a, null); +} + +function addAnnotationLabel( + aGroup, label: string, a: render.Annotation, additionalClassNames) { + let classNames = Class.Annotation.LABEL; + if (additionalClassNames) { + classNames += ' ' + additionalClassNames; + } + let txtElement = aGroup.append('text') + .attr('class', classNames) + .attr('dy', '.35em') + .attr('text-anchor', a.isIn ? 'end' : 'start') + .text(label); + + return tf.graph.scene.node.enforceLabelWidth(txtElement, -1); +} + +function addInteraction(selection, d: render.RenderNodeInfo, + annotation: render.Annotation, sceneElement) { + selection + .on('mouseover', + a => { + sceneElement.fire( + 'annotation-highlight', + {name: a.node.name, hostName: d.node.name}); + }) + .on('mouseout', + a => { + sceneElement.fire( + 'annotation-unhighlight', + {name: a.node.name, hostName: d.node.name}); + }) + .on('click', a => { + // Stop this event's propagation so that it isn't also considered a + // graph-select. + (d3.event).stopPropagation(); + sceneElement.fire( + 'annotation-select', {name: a.node.name, hostName: d.node.name}); + }); + if (annotation.annotationType !== render.AnnotationType.SUMMARY && + annotation.annotationType !== render.AnnotationType.CONSTANT) { + selection.on( + 'contextmenu', contextmenu.getMenu( + node.getContextMenu(annotation.node, sceneElement))); + } +}; + +/** + * Adjust annotation's position. + * + * @param aGroup selection of a 'g.annotation' element. + * @param d Host node data. + * @param a annotation node data. + * @param sceneElement polymer element. + */ +function update(aGroup, d: render.RenderNodeInfo, a: render.Annotation, + sceneElement) { + let cx = layout.computeCXPositionOfNodeShape(d); + // Annotations that point to embedded nodes (constants,summary) + // don't have a render information attached so we don't stylize these. + // Also we don't stylize ellipsis annotations (the string '... and X more'). + if (a.renderNodeInfo && + a.annotationType !== render.AnnotationType.ELLIPSIS) { + node.stylize(aGroup, a.renderNodeInfo, sceneElement, + Class.Annotation.NODE); + } + + if (a.annotationType === render.AnnotationType.SUMMARY) { + // Update the width of the annotation to give space for the image. + a.width += 10; + } + + // label position + aGroup.select('text.' + Class.Annotation.LABEL).transition().attr({ + x: cx + a.dx + (a.isIn ? -1 : 1) * (a.width / 2 + a.labelOffset), + y: d.y + a.dy + }); + + // Some annotations (such as summary) are represented using a 12x12 image tag. + // Purposely omitted units (e.g. pixels) since the images are vector graphics. + // If there is an image, we adjust the location of the image to be vertically + // centered with the node and horizontally centered between the arrow and the + // text label. + aGroup.select('use.summary').transition().attr({ + x: cx + a.dx - 3, + y: d.y + a.dy - 6 + }); + + // Node position (only one of the shape selection will be non-empty.) + positionEllipse( + aGroup.select('.' + Class.Annotation.NODE + ' ellipse'), cx + a.dx, + d.y + a.dy, a.width, a.height); + positionRect( + aGroup.select('.' + Class.Annotation.NODE + ' rect'), cx + a.dx, + d.y + a.dy, a.width, a.height); + positionRect( + aGroup.select('.' + Class.Annotation.NODE + ' use'), cx + a.dx, + d.y + a.dy, a.width, a.height); + + // Edge position + aGroup.select('path.' + Class.Annotation.EDGE).transition().attr('d', a => { + // map relative position to absolute position + let points = a.points.map(p => { return {x: p.dx + cx, y: p.dy + d.y}; }); + return edge.interpolate(points); + }); +}; + +} // close module diff --git a/tensorflow/tensorboard/components/tf_graph_common_d3v4/colors.ts b/tensorflow/tensorboard/components/tf_graph_common_d3v4/colors.ts new file mode 100644 index 0000000000..40f91f7d2d --- /dev/null +++ b/tensorflow/tensorboard/components/tf_graph_common_d3v4/colors.ts @@ -0,0 +1,130 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the 'License'); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an 'AS IS' BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +module tf { + /** + * Mapping from color palette name to color palette, which contains + * exact colors for multiple states of a single color palette. + */ + export let COLORS = [ + { + 'name': 'Google Blue', + 'color': '#4184f3', + 'active': '#3a53c5', + 'disabled': '#cad8fc' + }, + { + 'name': 'Google Red', + 'color': '#db4437', + 'active': '#8f2a0c', + 'disabled': '#e8c6c1' + }, + { + 'name': 'Google Yellow', + 'color': '#f4b400', + 'active': '#db9200', + 'disabled': '#f7e8b0' + }, + { + 'name': 'Google Green', + 'color': '#0f9d58', + 'active': '#488046', + 'disabled': '#c2e1cc' + }, + { + 'name': 'Purple', + 'color': '#aa46bb', + 'active': '#5c1398', + 'disabled': '#d7bce6' + }, + { + 'name': 'Teal', + 'color': '#00abc0', + 'active': '#47828e', + 'disabled': '#c2eaf2' + }, + { + 'name': 'Deep Orange', + 'color': '#ff6f42', + 'active': '#ca4a06', + 'disabled': '#f2cbba' + }, + { + 'name': 'Lime', + 'color': '#9d9c23', + 'active': '#7f771d', + 'disabled': '#f1f4c2' + }, + { + 'name': 'Indigo', + 'color': '#5b6abf', + 'active': '#3e47a9', + 'disabled': '#c5c8e8' + }, + { + 'name': 'Pink', + 'color': '#ef6191', + 'active': '#ca1c60', + 'disabled': '#e9b9ce' + }, + { + 'name': 'Deep Teal', + 'color': '#00786a', + 'active': '#2b4f43', + 'disabled': '#bededa' + }, + { + 'name': 'Deep Pink', + 'color': '#c1175a', + 'active': '#75084f', + 'disabled': '#de8cae' + }, + { + 'name': 'Gray', + 'color': '#9E9E9E', // 500 + 'active': '#424242', // 800 + 'disabled': 'F5F5F5' // 100 + } + ].reduce((m, c) => { + m[c.name] = c; + return m; + }, {}); + + /** + * Mapping from op category to color palette name + * e.g., OP_GROUP_COLORS['state_ops'] = 'Google Blue'; + */ + export let OP_GROUP_COLORS = [ + { + color: 'Google Red', + groups: [ + 'gen_legacy_ops', 'legacy_ops', 'legacy_flogs_input', + 'legacy_image_input', 'legacy_input_example_input', + 'legacy_sequence_input', 'legacy_seti_input_input' + ] + }, + {color: 'Deep Orange', groups: ['constant_ops']}, + {color: 'Indigo', groups: ['state_ops']}, + {color: 'Purple', groups: ['nn_ops', 'nn']}, + {color: 'Google Green', groups: ['math_ops']}, + {color: 'Lime', groups: ['array_ops']}, + {color: 'Teal', groups: ['control_flow_ops', 'data_flow_ops']}, + {color: 'Pink', groups: ['summary_ops']}, + {color: 'Deep Pink', groups: ['io_ops']} + ].reduce((m, c) => { + c.groups.forEach(function(group) { m[group] = c.color; }); + return m; + }, {}); +} diff --git a/tensorflow/tensorboard/components/tf_graph_common_d3v4/common.ts b/tensorflow/tensorboard/components/tf_graph_common_d3v4/common.ts new file mode 100644 index 0000000000..e7eac54e58 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_graph_common_d3v4/common.ts @@ -0,0 +1,31 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the 'License'); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an 'AS IS' BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +/** + * @fileoverview Common interfaces for the tensorflow graph visualizer. + */ + +module tf { + /** + * Tracks task progress. Each task being passed a progress tracker needs + * to call the below-defined methods to notify the caller about the gradual + * progress of the task. + */ + export interface ProgressTracker { + updateProgress(incrementValue: number): void; + setMessage(msg: string): void; + reportError(msg: string, err: Error): void; + } +} // close module tf diff --git a/tensorflow/tensorboard/components/tf_graph_common_d3v4/contextmenu.ts b/tensorflow/tensorboard/components/tf_graph_common_d3v4/contextmenu.ts new file mode 100644 index 0000000000..8121cf9f6d --- /dev/null +++ b/tensorflow/tensorboard/components/tf_graph_common_d3v4/contextmenu.ts @@ -0,0 +1,75 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the 'License'); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an 'AS IS' BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +module tf.graph.scene.contextmenu { + +/** Function that converts data to a title string. */ +export interface TitleFunction { + (data: any): string; +} + +/** Function that takes action based on item clicked in the context menu. */ +export interface ActionFunction { + (elem: any, d: any, i: number): void; +} + +/** + * The interface for an item in the context menu + */ +export interface ContextMenuItem { + title: TitleFunction; + action: ActionFunction; +} + +/** + * Returns the event listener, which can be used as an argument for the d3 + * selection.on function. Renders the context menu that is to be displayed + * in response to the event. + */ +export function getMenu(menu: ContextMenuItem[]) { + let menuSelection = d3.select('.context-menu'); + // Close the menu when anything else is clicked. + d3.select('body').on( + 'click.context', function() { menuSelection.style('display', 'none'); }); + + // Function called to populate the context menu. + return function(data, index: number): void { + // Position and display the menu. + let event = d3.event; + menuSelection + .style('display', 'block') + .style('left', (event.layerX + 1) + 'px') + .style('top', (event.layerY + 1) + 'px'); + + // Stop the event from propagating further. + event.preventDefault(); + event.stopPropagation(); + + // Add provided items to the context menu. + menuSelection.html(''); + let list = menuSelection.append('ul'); + list.selectAll('li') + .data(menu) + .enter() + .append('li') + .html(function(d) { return d.title(data); }) + .on('click', (d, i) => { + d.action(this, data, index); + menuSelection.style('display', 'none'); + }); + }; +}; + +} // close module diff --git a/tensorflow/tensorboard/components/tf_graph_common_d3v4/edge.ts b/tensorflow/tensorboard/components/tf_graph_common_d3v4/edge.ts new file mode 100644 index 0000000000..f3768e169b --- /dev/null +++ b/tensorflow/tensorboard/components/tf_graph_common_d3v4/edge.ts @@ -0,0 +1,343 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the 'License'); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an 'AS IS' BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +module tf.graph.scene.edge { + +/** Delimiter between dimensions when showing sizes of tensors. */ +const TENSOR_SHAPE_DELIM = '×'; + +/** The minimum stroke width of an edge. */ +export const MIN_EDGE_WIDTH = 0.75; + +/** The maximum stroke width of an edge. */ +export const MAX_EDGE_WIDTH = 12; + +/** The exponent used in the power scale for edge thickness. */ +const EDGE_WIDTH_SCALE_EXPONENT = 0.3; + +/** The domain (min and max value) for the edge width. */ +const DOMAIN_EDGE_WIDTH_SCALE = [1, 5E6]; + +export const EDGE_WIDTH_SCALE = d3.scalePow() + .exponent(EDGE_WIDTH_SCALE_EXPONENT) + .domain(DOMAIN_EDGE_WIDTH_SCALE) + .range([MIN_EDGE_WIDTH, MAX_EDGE_WIDTH]) + .clamp(true); + +let arrowheadMap = + d3.scaleQuantize().domain([MIN_EDGE_WIDTH, MAX_EDGE_WIDTH]).range([ + 'small', 'medium', 'large', 'xlarge' + ]); + +/** Minimum stroke width to put edge labels in the middle of edges */ +const CENTER_EDGE_LABEL_MIN_STROKE_WIDTH = 2.5; + +export type EdgeData = {v: string, w: string, label: render.RenderMetaedgeInfo}; + +export function getEdgeKey(edgeObj: EdgeData) { + return edgeObj.v + EDGE_KEY_DELIM + edgeObj.w; +} + +/** + * Select or Create a 'g.edges' group to a given sceneGroup + * and builds a number of 'g.edge' groups inside the group. + * + * Structure Pattern: + * + * + * + * + * + * ... + * + * + * + * @param sceneGroup container + * @param graph + * @param sceneElement polymer element. + * @return selection of the created nodeGroups + */ +export function buildGroup(sceneGroup, + graph: graphlib.Graph, + sceneElement) { + let edges: EdgeData[] = []; + edges = _.reduce(graph.edges(), (edges, edgeObj) => { + let edgeLabel = graph.edge(edgeObj); + edges.push({ + v: edgeObj.v, + w: edgeObj.w, + label: edgeLabel + }); + return edges; + }, edges); + + let container = + scene.selectOrCreateChild(sceneGroup, 'g', Class.Edge.CONTAINER); + + // Select all children and join with data. + // (Note that all children of g.edges are g.edge) + let edgeGroups = (container as any).selectAll('g.edge').data(edges, getEdgeKey); + + // Make edges a group to support rendering multiple lines for metaedge + edgeGroups.enter() + .append('g') + .attr('class', Class.Edge.GROUP) + .attr('data-edge', getEdgeKey) + .each(function(d: EdgeData) { + let edgeGroup = d3.select(this); + d.label.edgeGroup = edgeGroup; + // index node group for quick highlighting + sceneElement._edgeGroupIndex[getEdgeKey(d)] = edgeGroup; + + // Add line during enter because we're assuming that type of line + // normally does not change. + appendEdge(edgeGroup, d, sceneElement); + }) + .merge(edgeGroups) + .each(position) + .each(function(d) { + stylize(d3.select(this), d, sceneElement); + }); + + edgeGroups.exit() + .each(d => { + delete sceneElement._edgeGroupIndex[getEdgeKey(d)]; + }) + .remove(); + return edgeGroups; +}; + +/** + * Returns the label for the given base edge. + * The label is the shape of the underlying tensor. + */ +export function getLabelForBaseEdge( + baseEdge: BaseEdge, renderInfo: render.RenderGraphInfo): string { + let node = renderInfo.getNodeByName(baseEdge.v); + if (node.outputShapes == null || node.outputShapes.length === 0) { + return null; + } + let shape = node.outputShapes[baseEdge.outputTensorIndex]; + if (shape == null) { + return null; + } + if (shape.length === 0) { + return 'scalar'; + } + return shape.map(size => { return size === -1 ? '?' : size; }) + .join(TENSOR_SHAPE_DELIM); +} + +/** + * Creates the label for the given metaedge. If the metaedge consists + * of only 1 tensor, and it's shape is known, the label will contain that + * shape. Otherwise, the label will say the number of tensors in the metaedge. + */ +export function getLabelForEdge(metaedge: Metaedge, + renderInfo: render.RenderGraphInfo): string { + let isMultiEdge = metaedge.baseEdgeList.length > 1; + return isMultiEdge ? + metaedge.baseEdgeList.length + ' tensors' : + getLabelForBaseEdge(metaedge.baseEdgeList[0], renderInfo); +} + +/** + * Shortens the path enought such that the tip of the start/end marker will + * point to the start/end of the path. The marker can be of arbitrary size. + * + * @param points Array of path control points. + * @param marker D3 selection of the svg element. + * @param isStart Is the marker a `start-marker`. If false, the marker is + * an `end-marker`. + * @return The new array of control points. + */ +function adjustPathPointsForMarker(points: render.Point[], + marker: d3.Selection, isStart: boolean): render.Point[] { + let lineFunc = d3.line() + .x(d => d.x) + .y(d => d.y); + let path = + d3.select(document.createElementNS('http://www.w3.org/2000/svg', 'path')) + .attr('d', lineFunc(points)); + let markerWidth = +marker.attr('markerWidth'); + let viewBox = marker.attr('viewBox').split(' ').map(Number); + let viewBoxWidth = viewBox[2] - viewBox[0]; + let refX = +marker.attr('refX'); + let pathNode = path.node(); + if (isStart) { + let fractionStickingOut = refX / viewBoxWidth; + let length = markerWidth * fractionStickingOut; + let point = pathNode.getPointAtLength(length); + // Figure out how many segments of the path we need to remove in order + // to shorten the path. + let segIndex = pathNode.getPathSegAtLength(length); + // Update the very first segment. + points[segIndex - 1] = {x: point.x, y: point.y}; + // Ignore every point before segIndex - 1. + return points.slice(segIndex - 1); + } else { + let fractionStickingOut = 1 - refX / viewBoxWidth; + let length = pathNode.getTotalLength() - markerWidth * fractionStickingOut; + let point = pathNode.getPointAtLength(length); + // Figure out how many segments of the path we need to remove in order + // to shorten the path. + let segIndex = pathNode.getPathSegAtLength(length); + // Update the very last segment. + points[segIndex] = {x: point.x, y: point.y}; + // Ignore every point after segIndex. + return points.slice(0, segIndex + 1); + } +} + +/** + * For a given d3 selection and data object, create a path to represent the + * edge described in d.label. + * + * If d.label is defined, it will be a RenderMetaedgeInfo instance. It + * will sometimes be undefined, for example for some Annotation edges for which + * there is no underlying Metaedge in the hierarchical graph. + */ +export function appendEdge(edgeGroup, d: EdgeData, + sceneElement: {renderHierarchy: render.RenderGraphInfo}, + edgeClass?: string) { + let size = 1; + if (d.label != null && d.label.metaedge != null) { + // There is an underlying Metaedge. + size = d.label.metaedge.totalSize; + } + edgeClass = edgeClass || Class.Edge.LINE; // set default type + + if (d.label && d.label.structural) { + edgeClass += ' ' + Class.Edge.STRUCTURAL; + } + // Give the path a unique id, which will be used to link + // the textPath (edge label) to this path. + let pathId = 'path_' + getEdgeKey(d); + let strokeWidth = sceneElement.renderHierarchy.edgeWidthScale(size); + + let path = edgeGroup.append('path') + .attr('id', pathId) + .attr('class', edgeClass) + .style('stroke-width', strokeWidth + 'px'); + + // Check if there is a reference edge and add an arrowhead of the right size. + if (d.label && d.label.metaedge && d.label.metaedge.numRefEdges) { + let markerId = `ref-arrowhead-${arrowheadMap(strokeWidth)}`; + path.style('marker-start', `url(#${markerId})`); + d.label.startMarkerId = markerId; + } + + if (d.label == null || d.label.metaedge == null) { + // There is no associated metaedge, thus no text. + // This happens for annotation edges. + return; + } + let labelForEdge = getLabelForEdge(d.label.metaedge, + sceneElement.renderHierarchy); + if (labelForEdge == null) { + // We have no information to show on this edge. + return; + } + + // Put edge label in the middle of edge only if the edge is thick enough. + let baseline = strokeWidth > CENTER_EDGE_LABEL_MIN_STROKE_WIDTH ? + 'central' : + 'text-after-edge'; + + edgeGroup.append('text') + .append('textPath') + .attr('xlink:href', '#' + pathId) + .attr('startOffset', '50%') + .attr('text-anchor', 'middle') + .attr('dominant-baseline', 'central') + .text(labelForEdge); +}; + +export let interpolate = d3.line<{x: number, y: number}>() + .curve(d3.curveBasis) + .x((d) => { return d.x;}) + .y((d) => { return d.y;}); + +/** + * Returns a tween interpolator for the endpoint of an edge path. + */ +function getEdgePathInterpolator(d: EdgeData, i: number, a: string) { + let renderMetaedgeInfo = d.label; + let adjoiningMetaedge = renderMetaedgeInfo.adjoiningMetaedge; + let points = renderMetaedgeInfo.points; + + // Adjust the path so that start/end markers point to the end + // of the path. + if (d.label.startMarkerId) { + points = adjustPathPointsForMarker( + points, d3.select('#' + d.label.startMarkerId), true); + } + if (d.label.endMarkerId) { + points = adjustPathPointsForMarker( + points, d3.select('#' + d.label.endMarkerId), false); + } + + if (!adjoiningMetaedge) { + return d3.interpolate(a, interpolate(points)); + } + + let renderPath = this; + + // Get the adjoining path that matches the adjoining metaedge. + let adjoiningPath = + ((adjoiningMetaedge.edgeGroup.node()) + .firstChild); + + // Find the desired SVGPoint along the adjoining path, then convert those + // coordinates into the space of the renderPath using its Current + // Transformation Matrix (CTM). + let inbound = renderMetaedgeInfo.metaedge.inbound; + + return function(t) { + let adjoiningPoint = adjoiningPath + .getPointAtLength(inbound ? adjoiningPath.getTotalLength() : 0) + .matrixTransform(adjoiningPath.getCTM()) + .matrixTransform(renderPath.getCTM().inverse()); + + // Update the relevant point in the renderMetaedgeInfo's points list, then + // re-interpolate the path. + let index = inbound ? 0 : points.length - 1; + points[index].x = adjoiningPoint.x; + points[index].y = adjoiningPoint.y; + let dPath = interpolate(points); + return dPath; + }; +} + +function position(d) { + d3.select(this) + .select('path.' + Class.Edge.LINE) + .transition() + .attrTween('d', getEdgePathInterpolator as any); +}; + +/** + * For a given d3 selection and data object, mark the edge as a control + * dependency if it contains only control edges. + * + * d's label property will be a RenderMetaedgeInfo object. + */ +function stylize(edgeGroup, d: EdgeData, stylize) { + edgeGroup.classed('faded', d.label.isFadedOut); + let metaedge = d.label.metaedge; + edgeGroup.select('path.' + Class.Edge.LINE) + .classed('control-dep', metaedge && !metaedge.numRegularEdges); +}; + +} // close module diff --git a/tensorflow/tensorboard/components/tf_graph_common_d3v4/externs.ts b/tensorflow/tensorboard/components/tf_graph_common_d3v4/externs.ts new file mode 100644 index 0000000000..7c0d168a42 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_graph_common_d3v4/externs.ts @@ -0,0 +1,85 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the 'License'); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an 'AS IS' BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +/** + * @fileoverview Extern declarations for tensorflow graph visualizer. + * This file contains compiler stubs for external dependencies whos + * implementations are defined at runtime. + */ + +declare module graphlib { + interface GraphOptions { + name?: string; + /** + * Direction for rank nodes. Can be TB, BT, LR, or RL, where T = top, + * B = bottom, L = left, and R = right. + */ + rankdir?: string; + type?: string|number; + /** Number of pixels between each rank in the layout. */ + ranksep?: number; + /** Number of pixels that separate nodes horizontally in the layout. */ + nodesep?: number; + /** Number of pixels that separate edges horizontally in the layout */ + edgesep?: number; + } + + export interface EdgeObject { + v: string; + w: string; + name?: string; + } + + export class Graph { + constructor(opt?: Object); + setNode(name: string, value?: N): void; + hasNode(name: string): boolean; + setEdge(fromName: string, toName: string, value?: E): void; + hasEdge(fromName: string, toName: string): boolean; + edge(fromName: string, toName: string): E; + edge(edgeObject: EdgeObject): E; + removeEdge(v: string, w: string): void; + nodes(): string[]; + node(name: string): N; + removeNode(name: string): void; + setGraph(graphOptions: GraphOptions): void; + graph(): GraphOptions; + nodeCount(): number; + neighbors(name: string): string[]; + successors(name: string): string[]; + predecessors(name: string): string[]; + edges(): EdgeObject[]; + outEdges(name: string): E[]; + inEdges(name: string): E[]; + /** + * Returns those nodes in the graph that have no in-edges. + * Takes O(|V|) time. + */ + sources(): string[]; + /** + * Remove the node with the id v in the graph or do nothing if + * the node is not in the graph. If the node was removed this + * function also removes any incident edges. Returns the graph, + * allowing this to be chained with other functions. Takes O(|E|) time. + */ + removeNode(name: string): Graph; + setParent(name: string, parentName: string): void; + } +} + +/** + * Declaring dagre var used for dagre layout. + */ +declare var dagre: {layout(graph: graphlib.Graph): void;}; diff --git a/tensorflow/tensorboard/components/tf_graph_common_d3v4/graph.ts b/tensorflow/tensorboard/components/tf_graph_common_d3v4/graph.ts new file mode 100644 index 0000000000..1b0abcfd85 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_graph_common_d3v4/graph.ts @@ -0,0 +1,1257 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the 'License'); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an 'AS IS' BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +module tf.graph { + +/** Delimiter used in node names to denote namespaces. */ +export const NAMESPACE_DELIM = '/'; +export const ROOT_NAME = '__root__'; + +/** Attribute key used for storing attributes that are too large. */ +export const LARGE_ATTRS_KEY = '_too_large_attrs'; +/** + * Maximum allowed size in bytes, before the attribute is considered large + * and filtered out of the graph. + */ +export const LIMIT_ATTR_SIZE = 1024; + +// Separator between the source and the destination name of the edge. +export const EDGE_KEY_DELIM = '--'; + +export enum GraphType {FULL, EMBEDDED, META, SERIES, CORE, SHADOW, BRIDGE, + EDGE}; +export enum NodeType {META, OP, SERIES, BRIDGE, ELLIPSIS}; + +/** Indicates if a node is to be included in the main graph when rendered. */ +export enum InclusionType {INCLUDE, EXCLUDE, UNSPECIFIED}; + +/** Indicates if a series is to be grouped in the graph when rendered. */ +export enum SeriesGroupingType {GROUP, UNGROUP}; + +/** Attribute key reserved for the shapes of the output tensors. */ +const OUTPUT_SHAPES_KEY = '_output_shapes'; + +/** Attribute key reserved for the XLA cluster that an op runs on. */ +const _XLA_CLUSTER_KEY = '_XlaCluster'; + +/** + * A BaseEdge is the label object (in the graphlib sense) for an edge in the + * original, full graph produced after parsing. Subsequent graphs, like those + * which belong to Metanodes, should not use BaseEdge objects, but instead + * contain Metaedges (which in turn may contain any number of BaseEdges). + */ +export interface BaseEdge extends graphlib.EdgeObject { + isControlDependency: boolean; + isReferenceEdge: boolean; + /** The index of the output tensor of the source node. */ + outputTensorIndex: number; +} + +/** + * A SlimGraph is inspired by graphlib.Graph, but having only the functionality + * that we need. + */ +export class SlimGraph { + nodes: { [nodeName: string]: OpNode }; + edges: BaseEdge[]; + + constructor() { + this.nodes = {}; + this.edges = []; + } +} + +export interface NormalizedInput { + name: string; + /** The index of the output tensor of the source node. */ + outputTensorIndex: number; + isControlDependency: boolean; +} + +export interface BuildParams { + enableEmbedding: boolean; + inEmbeddingTypes: string[]; + outEmbeddingTypes: string[]; + refEdges: { [inputEdge: string]: boolean }; +} + +/** + * The most basic information about a node in the hierarchical graph. + */ +export interface Node { + /** The name of the node, used frequently to look up nodes by name. */ + name: string; + /** Which type of node this is. */ + type: NodeType; + /** + * Whether this node is a type that may contain other nodes. Those types + * should extend from GroupNode. + * + * For an OpNode, isGroupNode will be false, even though it may have + * embeddings. These embedding Nodes will have their parentNode set to the + * OpNode. However, embeddings are later rendered as annotations, not as + * children to be made visible on expansion (like a Metanode or SeriesNode). + */ + isGroupNode: boolean; + /** + * The number of nodes this node represents. For OpNodes, this will be 1, and + * for GroupNodes it will be a count of the total number of descendents it + * contains. + */ + cardinality: number; + /** + * The Node which is this Node's parent. This is of type Node and not + * GroupNode because of embeddings, which will have a parent OpNode. + */ + parentNode: Node; + /** Runtime execution stats for this node, if available */ + stats: NodeStats; + /** If the node is to be included or excluded from the main graph when + * rendered. Defaults to UNSPECIFIED, which means that the rendering + * algorithm determines if it will be included or not. Then can be set to + * INCLUDE or EXCLUDE manually by the user. + */ + include: InclusionType; + /** + * Node attributes specify customizable visual aspects of a node and + * application-specific metadata associated with a node. The name + * 'nodeAttributes' is meant to avoid naming-conflicts with the 'attr' in + * subclasses of Node. + */ + nodeAttributes: {[key: string]: any;}; +} + +export type TensorShape = number[]; + +export interface OpNode extends Node { + op: string; + // The device on which the op ran. Null if it is unknown. + device: string; + attr: {key: string, value: any}[]; + inputs: NormalizedInput[]; + inEmbeddings: OpNode[]; + outEmbeddings: OpNode[]; + // The name of the SeriesNode that can contain this node in its series. + // If there is no such node, then this is null. + owningSeries: string; + /** + * Array of tensor shapes. Null if the number of output tensors is unknown, + * otherwise the length will equal the number of output tensors. + * + * Each tensor shape is an array of numbers, or null. Details: + * - null means unknown rank, and therefore entire shape is unknown. + * - [4, 2, 1] means rank-3 tensor of size 4x2x1. + * - [] means a scalar (rank-0 tensor). + * - [1] means rank-1 tensor of size 1 (not the same as scalar). + * - [5, -1, 3] means rank-3 tensor of shape is 5x?x3. The size + * of the middle dimension is unknown (encoded as -1). + */ + outputShapes: TensorShape[]; + // The XLA Cluster on which the op ran. Null if it is unknown. + xlaCluster: string; +} + +export interface BridgeNode extends Node { + /** + * Whether this bridge node represents edges coming into its parent node. + */ + inbound: boolean; +} + +/** + * A node that is used when there are more than the maximum number of allowed + * annotations hanging off of a node. This node represents an ellipsis + * annotation, indicating a number of additional annotations. + */ +export interface EllipsisNode extends Node { + /** + * The number of nodes this ellipsis represents. + */ + numMoreNodes: number; + + /** + * Sets the number of nodes this ellipsis represents and changes the node + * name accordingly. + */ + setNumMoreNodes(numNodes: number); +} + +export interface GroupNode extends Node { + /** + * The metagraph contains nodes and metaedges between the immediate children + * of this group. The node label objects may be other GroupNodes (like + * SeriesNodes and Metanodes) or individual OpNodes. All edge label objects + * are Metaedges, each of which contains references to the original + * BaseEdge(s) from which it was created. + */ + metagraph: graphlib.Graph; + + /** + * The bridgegraph contains only edges which link immediate children of this + * group with nodes outside of the metagraph. As in the metagraph, all edge + * label objects are Metaedges which contain references to the original + * BaseEdge(s) that contribute to it. + * + * For a Metaedge in the bridgegraph, its external endpoint will be the same + * as the metagraph edge from which it came. This is most easily explained + * by example. + * + * Consider an original graph that contains a BaseEdge A/B/C->Z/Y/X. + * + * +-------+ (BaseEdge) +-------+ + * | A/B/C |>----------------->| Z/Y/X | + * +-------+ +-------+ + * + * When we construct the Root's metagraph, it will contain nodes for A and Z, + * and a Metaedge A->Z. The A->Z Metaedge will contain the original BaseEdge + * A/B/C->Z/Y/X in its baseEdgeGraph. The Root's bridgegraph will always be + * empty. + * + * +---+ (Root.metagraph edge) +---+ + * | A |>--------------------------->| Z | + * +---+ +---+ + * + * Now consider the Metanode A. Its metagraph will contain a Metanode for A/B + * and no edges. A's bridgegraph will have one Metaedge from A/B->Z, which + * was derived from the Root's Metaedge A->Z. That Metaedge will contain the + * original BaseEdge in its baseEdgeGraph. + * + * +---------+ + * | A | + * | +---+ | (A.bridgegraph edge) +---+ + * | | B |>---------------------------->| Z | + * | +---+ | +---+ + * +---------+ + * + * Finally, consider the Metanode A/B. Its metagraph will contain a Metanode + * for A/B/C and again no edges. A/B's bridgegraph will have one Metaedge + * from A/B/C->Z, which was derived from A's bridgegraph Metaedge A/B->Z. + * As before, the A/B/C->Z Metaedge will contain the original BaseEdge in its + * baseEdgeGraph. + * + * +---------------+ + * | A | + * | +---------+ | + * | | B | | + * | | +---+ | | (A/B.bridgegraph edge) +---+ + * | | | C |>----------------------------------->| Z | + * | | +---+ | | +---+ + * | +---------+ | + * +---------------+ + * + * Likewise, under the Metanode Z and Z/Y, to compute the bridgegraph, we'll + * end up with Metaedges A->Z/Y and A->Z/Y/X respectively. So the original + * BaseEdge A/B/C->Z/Y/X becomes four different Metaedges in four different + * bridgegraphs: + * + * + A/B->Z in GroupNode A's bridgegraph, + * + A/B/C->Z in GroupNode A/B's bridgegraph, + * + A->Z/Y in GroupNode Z's bridgegraph, and + * + A->Z/Y/X in GroupNode Z/Y's bridgegraph. + * + * Considering any BaseEdge then, if N is the number of path segments in the + * source and M is the number of path segments in the destination, then the + * total number of bridgegraph edges you could create would be (N-1)(M-1). + * + * For this reason, it is computationally expensive to generate all the + * bridgegraphs for all the Metanodes, and instead they should be computed + * on demand as needed. + */ + bridgegraph: graphlib.Graph; + + /** + * Stores how many times each device name appears in its children + * op nodes. Used to color group nodes by devices. + */ + deviceHistogram: {[device: string]: number}; + + /** + * Flag indicating whether this GroupNode's metagraph contains any edges that + * are not control edges. Used to quickly determine how to draw a collapsed + * series (vertically or horizontally). + */ + hasNonControlEdges: boolean; +} + +export interface Metanode extends GroupNode { + depth: number; + templateId: string; + opHistogram: {[op: string]: number}; + getFirstChild(): GroupNode|OpNode; + getRootOp(): OpNode; + /** Return name of all leaves inside a metanode. */ + leaves(): string[]; +} + +export interface SeriesNode extends GroupNode { + hasLoop: boolean; + prefix: string; + suffix: string; + clusterId: number; + ids: number[]; + parent: string; +} + +export class EllipsisNodeImpl implements EllipsisNode { + name: string; + numMoreNodes: number; + stats: NodeStats; + type: NodeType; + isGroupNode: boolean; + cardinality: number; + parentNode: Node; + include: InclusionType; + nodeAttributes: {[key: string]: any;}; + /** + * Constructs a new ellipsis annotation node. + * + * @param numNodes The number of additional annotations this node represents. + */ + constructor(numNodes: number) { + this.type = NodeType.ELLIPSIS; + this.isGroupNode = false; + this.cardinality = 1; + this.parentNode = null; + this.stats = null; + this.setNumMoreNodes(numNodes); + this.include = InclusionType.UNSPECIFIED; + } + + setNumMoreNodes(numNodes: number) { + this.numMoreNodes = numNodes; + this.name = '... ' + numNodes + ' more'; + } +}; + +/** + * A label object for nodes in the full graph and leaf nodes in the render + * graph. + */ +export class OpNodeImpl implements OpNode { + name: string; + op: string; + device: string; + stats: NodeStats; + attr: {key: string, value: any}[]; + inputs: NormalizedInput[]; + type: NodeType; + isGroupNode: boolean; + cardinality: number; + inEmbeddings: OpNode[]; + outEmbeddings: OpNode[]; + parentNode: Node; + include: InclusionType; + owningSeries: string; + outputShapes: TensorShape[]; + nodeAttributes: {[key: string]: any;}; + xlaCluster: string; + + /** + * Constructs a new Op node. + * + * @param rawNode The raw node. + */ + constructor(rawNode: tf.graph.proto.NodeDef) { + this.op = rawNode.op; + this.name = rawNode.name; + this.device = rawNode.device; + this.attr = rawNode.attr; + // An array of normalized inputs that denote the incoming edges to + // the current node. Each input contains the normalized name of the + // source node, whether it has a number part and whether it is a + // control dependency. + this.inputs = normalizeInputs(rawNode.input); + this.outputShapes = extractOutputShapes(rawNode.attr); + this.xlaCluster = extractXlaCluster(rawNode.attr); + // additional properties + this.type = NodeType.OP; + this.isGroupNode = false; + this.cardinality = 1; + this.inEmbeddings = []; + this.outEmbeddings = []; + this.parentNode = null; + this.include = InclusionType.UNSPECIFIED; + this.owningSeries = null; + } +}; + +export function createMetanode(name: string, opt = {}): Metanode { + return new MetanodeImpl(name, opt); +} + +/** + * Joins the information from the stats file (memory, compute time) with the + * graph information. + */ +export function joinStatsInfoWithGraph( + graph: SlimGraph, stats: tf.graph.proto.StepStats, + devicesForStats?: {[device: string]: boolean}): void { + // Reset stats for each node. + _.each(graph.nodes, node => { node.stats = null; }); + + _.each(stats.dev_stats, devStats => { + // Ignore devices that are not selected. + if (devicesForStats && !devicesForStats[devStats.device]) { + return; + } + _.each(devStats.node_stats, nodeStats => { + // Lookup the node in the graph by its original name, e.g. A. If not + // found, lookup by the rewritten name A/(A) in case the name is both + // a namespace and a node name. + let nodeName = nodeStats.node_name in graph.nodes ? nodeStats.node_name : + nodeStats.node_name + + NAMESPACE_DELIM + '(' + nodeStats.node_name + ')'; + + // Couldn't find a matching node. + if (!(nodeName in graph.nodes)) { + return; + } + + // Compute the total bytes used. + let totalBytes = 0; + if (nodeStats.memory) { + _.each(nodeStats.memory, alloc => { + if (alloc.total_bytes) { + if (alloc.total_bytes > 0) { + totalBytes += Number(alloc.total_bytes); + } else { + /* tslint:disable */ + console.log( + 'ignoring negative memory allocation for ' + nodeName); + /* tslint:enable */ + } + } + }); + } + let outputSize: number[][] = null; + if (nodeStats.output) { + outputSize = _.map(nodeStats.output, output => { + return _.map(output.tensor_description.shape.dim, + dim => Number(dim.size)); + }); + } + graph.nodes[nodeName].device = devStats.device; + if (graph.nodes[nodeName].stats == null) { + graph.nodes[nodeName].stats = new NodeStats(outputSize); + } + graph.nodes[nodeName].stats.addBytesAllocation(totalBytes); + if (nodeStats.all_end_rel_micros) { + if (nodeStats.all_end_rel_micros > 0) { + graph.nodes[nodeName].stats.addExecutionTime( + nodeStats.all_start_micros, + nodeStats.all_start_micros + nodeStats.all_end_rel_micros); + } else { + /* tslint:disable */ + console.log('ignoring negative runtime for ' + nodeName); + /* tslint:enable */ + } + } + }); + }); +} + +/** + * Execution stats for the node. + */ +export class NodeStats { + constructor(outputSize: number[][]) { this.outputSize = outputSize; } + + /** + * Add the start and end time for a particular kernel execution of this op. + * Ops can have multiple kernel executions within the same session run. + */ + addExecutionTime(startTime: number, endTime: number) { + if (this.startTime != null) { + this.startTime = Math.min(this.startTime, startTime); + } else { + this.startTime = startTime; + } + if (this.endTime != null) { + this.endTime = Math.max(this.endTime, endTime); + } else { + this.endTime = endTime; + } + } + + /** + * Add the bytes allocated for a particular kernel execution of this op. + * Ops can have multiple kernel executions within the same session run. + */ + addBytesAllocation(totalBytes: number) { + if (this.totalBytes != null) { + this.totalBytes = Math.max(this.totalBytes, totalBytes); + } else { + this.totalBytes = totalBytes; + } + } + + /** + * Absolute start time for the very first kernel execution of this op. + */ + startTime: number; + /** + * Absolute end time for the very last kernel execution of this op. + */ + endTime: number; + /** + * Total number of bytes used for the node. Sum of all children + * if it is a Group node. + */ + totalBytes = 0; + + /** + * The shape of each output tensors, if there are any. + * Empty if it is a Group node. + */ + outputSize: number[][]; + + /** + * Combines the specified stats with the current stats. + * Modifies the current object. This method is used to + * compute aggregate stats for group nodes. + */ + combine(stats: NodeStats): void { + if (stats.totalBytes != null) { + this.totalBytes += stats.totalBytes; + } + if (stats.getTotalMicros() != null) { + this.addExecutionTime(stats.startTime, stats.endTime); + } + } + + /** + * Total number of compute time in microseconds used for the node. + * Sum of all children if it is a Group node. Null if it is unknown. + * This method can not be scaffolded under a getter attribute because + * ECMAScript 5 does not support getter attributes. + */ + getTotalMicros(): number { + if (this.startTime == null || this.endTime == null) { + return null; + } + return this.endTime - this.startTime; + } +} + +export class MetanodeImpl implements Metanode { + name: string; + stats: NodeStats; + type: NodeType; + depth: number; + isGroupNode: boolean; + cardinality: number; + metagraph: graphlib.Graph; + bridgegraph: graphlib.Graph; + templateId: string; + opHistogram: {[op: string]: number}; + deviceHistogram: {[op: string]: number}; + parentNode: Node; + hasNonControlEdges: boolean; + include: InclusionType; + nodeAttributes: {[key: string]: any;}; + + /** A label object for meta-nodes in the graph hierarchy */ + constructor(name: string, opt = {}) { + this.name = name; + this.type = NodeType.META; + /** number of levels under this group */ + this.depth = 1; + this.isGroupNode = true; + /** # of leaf nodes (including embedded ones) */ + this.cardinality = 0; + /** graph contains metanodes, nodes, edges + * and metaedges for main items within this metanode + */ + this.metagraph = + createGraph(name, GraphType.META, opt); + /** bridgegraph must be constructed lazily-see hierarchy.getBridgegraph() */ + this.bridgegraph = null; + /** + * A dictionary that count ops type of nodes in this metanode + * (op type => count). + */ + this.opHistogram = {}; + this.deviceHistogram = {}; + /** unique id for a metanode of similar subgraph */ + this.templateId = null; + /** Metanode which contains this node, if any */ + this.parentNode = null; + this.hasNonControlEdges = false; + this.include = InclusionType.UNSPECIFIED; + } + + getFirstChild(): GroupNode|OpNode { + return this.metagraph.node(this.metagraph.nodes()[0]); + } + + /** + * Returns the op node associated with the metanode. + * For example, if the metanode is 'sgd', the associated + * op node is sgd/(sgd). + */ + getRootOp(): OpNode { + let nameSplit = this.name.split('/'); + let rootOpName = this.name + '/(' + nameSplit[nameSplit.length - 1] + ')'; + return this.metagraph.node(rootOpName); + } + + /** + * Return an array of the names of all the leaves (non-GroupNodes) inside + * this metanode. This performs a breadth-first search of the tree, so + * immediate child leaves will appear earlier in the output array than + * descendant leaves. + */ + leaves(): string[] { + let leaves = []; + let queue = [ this]; + let metagraph; // Defined here due to a limitation of ES6->5 compilation. + while (queue.length) { + let node = queue.shift(); + if (node.isGroupNode) { + metagraph = ( node).metagraph; + _.each(metagraph.nodes(), name => queue.push(metagraph.node(name))); + } else { + leaves.push(node.name); + } + } + return leaves; + } +}; + +export interface Metaedge extends graphlib.EdgeObject { + + /** + * Stores the original BaseEdges represented by this Metaedge. + */ + baseEdgeList: BaseEdge[]; + + /** + * Whether this edge represents a relationship that is inbound (or outbound) + * to the object which contains this information. For example, in a Metanode's + * bridgegraph, each edge connects an immediate child to something outside + * the Metanode. If the destination of the edge is inside the Metanode, then + * its inbound property should be true. If the destination is outside the + * Metanode, then its inbound property should be false. + * + * The property is optional because not all edges can be described as + * inbound/outbound. For example, in a Metanode's metagraph, all of the edges + * connect immediate children of the Metanode. None should have an inbound + * property, or they should be null/undefined. + */ + inbound?: boolean; + + /** + * Number of regular edges (not control dependency edges). + */ + numRegularEdges: number; + + /** + * Number of control dependency edges. + */ + numControlEdges: number; + + /** + * Number of reference edges, which is an edge to an operation + * that takes a reference to its input and changes its value. + */ + numRefEdges: number; + + /** + * Total size (number of units) of all the tensors flowing through this edge. + */ + totalSize: number; + + addBaseEdge(edge: BaseEdge, h: hierarchy.Hierarchy): void; +} + +export function createMetaedge(v: string, w: string): Metaedge { + return new MetaedgeImpl(v, w); +} + +/** + * A label object for edges between metanodes of subgraphs in the render graph. + */ +export class MetaedgeImpl implements Metaedge { + v: string; + w: string; + baseEdgeList: BaseEdge[]; + inbound: boolean; + numRegularEdges: number; + numControlEdges: number; + numRefEdges: number; + totalSize: number; + + constructor(v: string, w: string) { + this.v = v; + this.w = w; + this.baseEdgeList = []; + this.inbound = null; + this.numRegularEdges = 0; + this.numControlEdges = 0; + this.numRefEdges = 0; + this.totalSize = 0; + } + + addBaseEdge(edge: BaseEdge, h: hierarchy.Hierarchy): void { + this.baseEdgeList.push(edge); + if (edge.isControlDependency) { + this.numControlEdges += 1; + } else { + this.numRegularEdges += 1; + } + if (edge.isReferenceEdge) { + this.numRefEdges += 1; + } + // Compute the size of the tensor flowing through this + // base edge. + this.totalSize += MetaedgeImpl.computeSizeOfEdge(edge, h); + h.maxMetaEdgeSize = Math.max(h.maxMetaEdgeSize, this.totalSize); + } + + private static computeSizeOfEdge(edge: BaseEdge, h: hierarchy.Hierarchy): + number { + let opNode = h.node(edge.v); + if (opNode.outputShapes == null) { + // No shape information. Asssume a single number. This gives + // a lower bound for the total size. + return 1; + } + h.hasShapeInfo = true; + // Sum the sizes of all output tensors. + return _(opNode.outputShapes).map(shape => { + // If the shape is unknown, treat it as 1 when computing + // total size. This gives a lower bound for the total size. + if (shape == null) { + return 1; + } + // Multiply all shapes to get the total size of the tensor. + // E.g. The total size of [4, 2, 1] is 4 * 2 * 1. + return _(shape).reduce((accumulated, currSize) => { + // If this particular dimension is unknown, treat + // it as 1 when computing total size. This gives a lower bound + // for the total size. + if (currSize === -1) { + currSize = 1; + } + return accumulated * currSize; + }, 1); + }).sum(); + } +} + +export function createSeriesNode(prefix: string, suffix: string, + parent: string, clusterId: number, name: string): SeriesNode { + return new SeriesNodeImpl(prefix, suffix, parent, clusterId, name); +} + +export function getSeriesNodeName(prefix: string, suffix: string, + parent: string, startId?: number, endId?: number): string { + let numRepresentation = + (typeof startId !== 'undefined' && typeof endId !== 'undefined') ? + '[' + startId + '-' + endId + ']' : + '#'; + let pattern = prefix + numRepresentation + suffix; + return (parent ? parent + '/' : '') + pattern; +} + +class SeriesNodeImpl implements SeriesNode { + name: string; + type: NodeType; + stats: NodeStats; + hasLoop: boolean; + prefix: string; + suffix: string; + clusterId: number; + ids: number[]; + parent: string; + isGroupNode: boolean; + cardinality: number; + metagraph: graphlib.Graph; + bridgegraph: graphlib.Graph; + parentNode: Node; + deviceHistogram: {[op: string]: number}; + hasNonControlEdges: boolean; + include: InclusionType; + nodeAttributes: {[key: string]: any;}; + + constructor(prefix: string, suffix: string, parent: string, + clusterId: number, name: string) { + this.name = name || getSeriesNodeName(prefix, suffix, parent); + this.type = NodeType.SERIES; + this.hasLoop = false; + this.prefix = prefix; + this.suffix = suffix; + this.clusterId = clusterId; + this.ids = []; + this.parent = parent; + this.isGroupNode = true; + this.cardinality = 0; + this.metagraph = createGraph(name, GraphType.SERIES); + // bridgegraph must be constructed lazily-see hierarchy.getBridgegraph() + this.bridgegraph = null; + this.parentNode = null; + this.deviceHistogram = {}; + this.hasNonControlEdges = false; + this.include = InclusionType.UNSPECIFIED; + } +} + +/** + * Extracts the shapes of the output tensors from the attr property in the + * node proto. + */ +// tslint:disable-next-line:no-any +function extractOutputShapes(attr: Array<{key: string, value: any}>): + TensorShape[] { + let result = null; + // We don't know anything about the output tensors. + if (!attr) { + return null; + } + for (let i = 0; i < attr.length; i++) { + let {key, value} = attr[i]; + if (key === OUTPUT_SHAPES_KEY) { + if (!value.list.shape) { + // The OUTPUT_SHAPES_KEY lacks a value. We know nothing about the shape. + return null; + } + + // Map all output tensors into array of numbers denoting their shape. + let result = value.list.shape.map(shape => { + if (shape.unknown_rank) { + // This output tensor is of unknown rank. We don't know if it is a + // scalar, or a tensor, or of what shape it is. + return null; + } + if (shape.dim == null || + (shape.dim.length === 1 && shape.dim[0].size == null)) { + // This output tensor is a scalar. + return []; + } + // This output tensor has a known rank. Map each dimension size + // into a number. + return shape.dim.map(dim => { + // Size can be -1 if this particular dimension is unknown. + return dim.size; + }); + }); + // Since we already processed it, remove the entry from the attribute + // list (saves memory). + attr.splice(i, 1); + return result; + } + } + // We didn't find OUTPUT_SHAPES_KEY in attributes, so we don't know anything + // about the output tensors. + return null; +} + +/** + * Extracts the XLA Cluster that an op runs on from the attrs of the OpNode. + * @param attr The attr property. + * @return A string that is the name of the cluster. Or null if it could not be + * determined. + */ +// tslint:disable-next-line:no-any +function extractXlaCluster(attr: Array<{key: string, value: any}>): string| + null { + if (!attr) { + return null; + } + + // Find the attribute for XLA cluster if there is one. + for (let i = 0; i < attr.length; i++) { + if (attr[i].key === _XLA_CLUSTER_KEY) { + return attr[i].value['s'] || null; + } + } + return null; +} + +/** + * Normalizes the inputs and extracts associated metadata: + * 1) Inputs can contain a colon followed by a number at the end + * (e.g. inputName:1) and we remove this from the input name, and take note + * that the input was numbered. + * 2) Control dependency inputs contain caret at the beginning and we + * remove this and annotate the edge as a control dependency. + * @param inputs Array of unnormalized names of input nodes. + */ +function normalizeInputs(inputs: string[]): NormalizedInput[] { + let normalizedInputs: NormalizedInput[] = []; + _.each(inputs, inputName => { + let start = inputName[0] === '^'; + let colon = inputName.lastIndexOf(':'); + let end = colon !== -1 && + inputName.length - colon > 1 && + !(/\D/).test(inputName.substring(colon + 1)) ? + colon : inputName.length; + let name = inputName.substring(start ? 1 : 0, end); + if (normalizedInputs.length === 0 || + name !== normalizedInputs[normalizedInputs.length - 1].name) { + normalizedInputs.push({ + name: name, + outputTensorIndex: + end === inputName.length ? 0 : Number(inputName.slice(colon + 1)), + isControlDependency: start + }); + } + }); + return normalizedInputs; +} + +function addEdgeToGraph( + graph: SlimGraph, inputName: string, outputNode: OpNode, + input: NormalizedInput, params: BuildParams, index: number) { + // Don't allow loops in the graph. + if (inputName === outputNode.name) { + return; + } + // Check if this op type and input number corresponds to a + // reference edge using the refEdges dictionary in the params. + let isRefEdge = params.refEdges[outputNode.op + ' ' + index] === true; + graph.edges.push({ + v: inputName, + w: outputNode.name, + outputTensorIndex: input.outputTensorIndex, + isControlDependency: input.isControlDependency, + isReferenceEdge: isRefEdge + }); +} + +export function build( + rawNodes: tf.graph.proto.NodeDef[], params: BuildParams, + tracker: ProgressTracker): Promise { + /** + * A dictionary that maps each in-embedding node name to the node + * object. + */ + let inEmbedding: {[nodeName: string]: OpNode} = {}; + /** + * A dictionary that maps each out-embedding node name to the node + * object. + */ + let outEmbedding: {[nodeName: string]: OpNode} = {}; + /** + * A dictionary that maps each node name to an array of the node's + * out-embedding node label objects. + */ + let outEmbeddings: {[inputName: string]: OpNode[]} = {}; + let isInEmbeddedPred = getEmbedPredicate(params.inEmbeddingTypes); + let isOutEmbeddedPred = getEmbedPredicate(params.outEmbeddingTypes); + let embeddingNodeNames: string[] = []; + /** + * A list of all the non-embedding node names which appear in the processed + * list of raw nodes. Here we pre-allocate enough room for all the rawNodes, + * even though there will some number of embeddings. The excess array length + * is spliced off later. + * + * Experimentation shows that around 30% of the array will go unused, and + * even for very large networks that amounts to less than 10k spaces. + */ + let nodeNames = new Array(rawNodes.length); + + return tf.graph.util + .runAsyncTask( + 'Normalizing names', 30, + () => { + let opNodes = new Array(rawNodes.length); + let index = 0; + _.each(rawNodes, rawNode => { + let opNode = new OpNodeImpl(rawNode); + if (isInEmbeddedPred(opNode)) { + embeddingNodeNames.push(opNode.name); + inEmbedding[opNode.name] = opNode; + return; + } + + if (isOutEmbeddedPred(opNode)) { + embeddingNodeNames.push(opNode.name); + outEmbedding[opNode.name] = opNode; + _.each(opNode.inputs, input => { + let inputName = input.name; + outEmbeddings[inputName] = outEmbeddings[inputName] || []; + outEmbeddings[inputName].push(opNode); + }); + return; + } + // The node is not an embedding, so add it to the names and nodes + // lists. + opNodes[index] = opNode; + nodeNames[index] = opNode.name; + index++; + }); + opNodes.splice(index); + nodeNames.splice(index); + return opNodes; + }, + tracker) + .then((opNodes) => { + // Create the graph data structure from the graphlib library. + return tf.graph.util.runAsyncTask( + 'Building the data structure', 70, () => { + let normalizedNameDict = + mapStrictHierarchy(nodeNames, embeddingNodeNames); + let graph = new SlimGraph; + + // Add the nodes to the graph. + _.each(opNodes, opNode => { + let normalizedName = + normalizedNameDict[opNode.name] || opNode.name; + graph.nodes[normalizedName] = opNode; + // Check if the node has out-embeddings. If yes, add them to the + // node. + if (opNode.name in outEmbeddings) { + opNode.outEmbeddings = outEmbeddings[opNode.name]; + // Normalize the names of the out-embeddings. + _.each(opNode.outEmbeddings, node => { + node.name = normalizedNameDict[node.name] || node.name; + }); + } + // Update the name of the node. + opNode.name = normalizedName; + }); + + // Visit each node's inputs to add the edges to the graph. If the + // input + // is an in-embedding, then add it to the node's in-embeddings + // instead. + _.each(opNodes, opNode => { + _.each(opNode.inputs, (input, i) => { + let inputName = input.name; + if (inputName in inEmbedding) { + let inEmbedNode = inEmbedding[inputName]; + opNode.inEmbeddings.push(inEmbedNode); + // Move the inputs of the in-embedding node into incoming + // edges of + // the main node. E.g. the control dependency of a constant + // node + // should be moved to the op node where the constant is + // embedded. + for (let embedInput of inEmbedNode.inputs) { + addEdgeToGraph( + graph, normalizedNameDict[embedInput.name] || + embedInput.name, + opNode, embedInput, params, i); + } + } else if (inputName in outEmbedding) { + // Move the inputs of the out-embedding node into inputs of + // the main node where the out-embedding points to. + let outEmbedNode = outEmbedding[inputName]; + for (let embedInput of outEmbedNode.inputs) { + addEdgeToGraph( + graph, normalizedNameDict[embedInput.name] || + embedInput.name, + opNode, input, params, i); + } + } else { + addEdgeToGraph( + graph, normalizedNameDict[inputName] || inputName, + opNode, input, params, i); + } + }); + }); + + // Normalize the names of in-embeddings. + _.each(inEmbedding, (node, name) => { + node.name = normalizedNameDict[node.name] || node.name; + }); + + return graph; + }, tracker); + }); +}; + +/** + * Create a new graphlib.Graph() instance with default parameters + */ +export function createGraph(name: string, type, opt = {}): + graphlib.Graph { + let graph = new graphlib.Graph(opt); + graph.setGraph({ + name: name, + rankdir: 'BT', // BT,TB,LR,RL + type: type + }); + return graph; +}; + +/** + * Create a predicate for checking whether a node should be embedded based on + * the specified types. + */ +function getEmbedPredicate(types: string[]) { + return function(node: OpNode) { + // check types + for (let i = 0; i < types.length; i++) { + let regExp = new RegExp(types[i]); + if (node.op.match(regExp)) { return true; } + } + return false; + }; +}; + +/** + * Returns a strict node name (name => name/(name)) to avoid conflicts + * where the node name is also a namespace. + */ +export function getStrictName(name: string): string { + let parts = name.split(NAMESPACE_DELIM); + return name + NAMESPACE_DELIM + '(' + parts[parts.length - 1] + ')'; +} + +/** + * For each op node (embedding or non-embedding), rename it if there is a + * non-embedding node under its namespace. For example, assume node name 'A'. + * If there is a non-embedding node under its namespace (e.g. 'A/B'), 'A' will + * be renamed to 'A/(A)'. Then the namespace 'A' will contain 2 nodes: '(A)' + * and 'B'. If all the nodes under 'A' are embedding nodes (e.g. constant and + * summary), keep 'A' as an Op node and don't create a namespace. + * + * @param nodeNames An array of regular (non-embedding) node names. + * @param embeddingNodeNames An array of embedding node names. + * @return Dictionary object mapping names that need to be renamed to + * new names. + */ +function mapStrictHierarchy(nodeNames: string[], + embeddingNodeNames: string[]): {[oldName: string]: string} { + /** Dictionary that maps the old new to the new name */ + let newNameDictionary: {[oldName: string]: string} = {}; + /** Set used to store all namespaces. */ + let namespaceSet: {[namespace: string]: boolean} = {}; + // sort the nodes to make prefix check faster + nodeNames.sort(); + // look for nodes with a prefix a,a/b -> a/(a),a/b + for (let i = 0; i < nodeNames.length - 1; ++i) { + let a = nodeNames[i]; + // Get all the parent namespaces of the current node + // and add them in the namespace set. + _.each(getHierarchicalPath(a).slice(0, -1), ns => { + namespaceSet[ns] = true; + }); + for (let j = i + 1; j < nodeNames.length; ++j) { + let b = nodeNames[j]; + if (_.startsWith(b, a)) { + if (b.length > a.length && b.charAt(a.length) === NAMESPACE_DELIM) { + newNameDictionary[a] = getStrictName(a); + break; + } + } else { + break; + } + } + } + // Go through all the embedding node names and rename them in case they + // collide with namespaces. + _.each(embeddingNodeNames, embeddingName => { + if (embeddingName in namespaceSet) { + // Rename to follow strict hierarchy. + newNameDictionary[embeddingName] = getStrictName(embeddingName); + } + }); + return newNameDictionary; +}; + +/** + * Returns a list of the degrees of each node in the graph. + */ +function degreeSequence(graph: graphlib.Graph): number[] { + let degrees = graph.nodes().map(function(name) { + return graph.neighbors(name).length; + }); + degrees.sort(); + return degrees; +}; + +/** + * Returns if the degree sequence of the two graphs is the same. + */ +export function hasSimilarDegreeSequence(graph1: graphlib.Graph, + graph2: graphlib.Graph): boolean { + let dg1 = degreeSequence(graph1); + let dg2 = degreeSequence(graph2); + + for (let i = 0; i < dg1.length; i++) { + if (dg1[i] !== dg2[i]) { + return false; + } + } + return true; +}; + +/** + * Returns the hierarchical path of the current node, based on the node's name. + * For example, if the name is 'a/b/c', the returned path is + * ['a', 'a/b', 'a/b/c']. + */ +export function getHierarchicalPath(name: string, + seriesNames?: { [name: string]: string }): string[] { + let path: string[] = []; + let i = name.indexOf(NAMESPACE_DELIM); + // Push all parent portions of the path. + while (i >= 0) { + path.push(name.substring(0, i)); + i = name.indexOf(NAMESPACE_DELIM, i + 1); + } + // If the node's path is under a series, then add the series node name to the + // hierarchical path as the parent of the leaf. + if (seriesNames) { + let seriesName = seriesNames[name]; + if (seriesName) { + path.push(seriesName); + } + } + // Push the leaf of the path. + path.push(name); + return path; +}; + +/** + * Returns the string for the node inclusion toggle button, dependant + * on the provided current InclusionType. + */ +export function getIncludeNodeButtonString(include: InclusionType) { + if (include === tf.graph.InclusionType.EXCLUDE) { + return 'Add to main graph'; + } else { + return 'Remove from main graph'; + } +}; + +/** + * Returns the string for the series node grouping toggle button, dependant + * on the provided current SeriesGroupingType. + */ +export function getGroupSeriesNodeButtonString(group: SeriesGroupingType) { + if (group === tf.graph.SeriesGroupingType.GROUP) { + return 'Ungroup this series of nodes'; + } else { + return 'Group this series of nodes'; + } +}; + +/** + * Toggle the node series grouping option in the provided map, setting it + * to ungroup if the series is not already in the map. + */ +export function toggleNodeSeriesGroup( + map: { [name: string]: tf.graph.SeriesGroupingType }, name: string) { + if (!(name in map) || map[name] === tf.graph.SeriesGroupingType.GROUP) { + map[name] = tf.graph.SeriesGroupingType.UNGROUP; + } else { + map[name] = tf.graph.SeriesGroupingType.GROUP; + } +}; + +} // close module tf.graph diff --git a/tensorflow/tensorboard/components/tf_graph_common_d3v4/hierarchy.ts b/tensorflow/tensorboard/components/tf_graph_common_d3v4/hierarchy.ts new file mode 100644 index 0000000000..889607ac50 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_graph_common_d3v4/hierarchy.ts @@ -0,0 +1,807 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the 'License'); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an 'AS IS' BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +/** + * Package for the Graph Hierarchy for TensorFlow graph. + */ +module tf.graph.hierarchy { + +/** + * Class used as output for getPredecessors and getSuccessors methods + */ +export interface Edges { + control: Metaedge[]; + regular: Metaedge[]; +} + +export interface Hierarchy { + root: Metanode; + templates: {[templateId: string]: string[]}; + /** List of all device names */ + devices: string[]; + /** List of all XLA cluster names */ + xlaClusters: string[]; + /** True if at least one tensor in the graph has shape information */ + hasShapeInfo: boolean; + /** The maximum size across all meta edges. Used for scaling thickness. */ + maxMetaEdgeSize: number; + getNodeMap(): {[nodeName: string]: GroupNode|OpNode}; + node(name: string): GroupNode|OpNode; + setNode(name: string, node: GroupNode|OpNode): void; + getBridgegraph(nodeName: string): graphlib.Graph; + getPredecessors(nodeName: string): Edges; + getSuccessors(nodeName: string): Edges; + getTopologicalOrdering(nodeName: string): { [childName: string]: number }; + getTemplateIndex(): (string) => number; +} + +/** + * Class for the Graph Hierarchy for TensorFlow graph. + */ +class HierarchyImpl implements Hierarchy { + root: Metanode; + templates: {[templateId: string]: string[]}; + private index: {[nodeName: string]: GroupNode|OpNode}; + devices: string[]; + xlaClusters: string[]; + hasShapeInfo = false; + maxMetaEdgeSize = 1; + orderings: { [nodeName: string]: { [childName: string]: number } }; + + constructor() { + this.root = createMetanode(ROOT_NAME, {compound: true}); + this.templates = null; + this.devices = null; + /** + * @type {Object} Dictionary object that maps node name to the node + * (could be op-node, metanode, or series-node) + */ + this.index = {}; + this.index[ROOT_NAME] = this.root; + this.orderings = {}; + } + + getNodeMap(): {[nodeName: string]: GroupNode|OpNode} { + return this.index; + } + + node(name: string): GroupNode|OpNode { + return this.index[name]; + } + + setNode(name: string, node: GroupNode|OpNode): void { + this.index[name] = node; + } + + /** + * Given the name of a node in this hierarchy, get its bridgegraph, creating + * it on the fly if necessary. If the node is not a GroupNode, then this + * method returns null. If the provided name does not map to a node in the + * hierarchy, an error will be thrown. + */ + getBridgegraph(nodeName: string): graphlib.Graph { + let node = this.index[nodeName]; + if (!node) { + throw Error('Could not find node in hierarchy: ' + nodeName); + } + if (!('metagraph' in node)) { + return null; + } + let groupNode = node; + if (groupNode.bridgegraph) { + return groupNode.bridgegraph; + } + let bridgegraph = groupNode.bridgegraph = + createGraph( + 'BRIDGEGRAPH', GraphType.BRIDGE); + if (!node.parentNode || !('metagraph' in node.parentNode)) { + return bridgegraph; + } + + let parentNode = node.parentNode; + let parentMetagraph = parentNode.metagraph; + let parentBridgegraph = this.getBridgegraph(parentNode.name); + + // For each of the parent node's two Metaedge containing graphs, process + // each Metaedge involving this node. + _.each([parentMetagraph, parentBridgegraph], parentGraph => { + _(parentGraph.edges()) + .filter(e => e.v === nodeName || e.w === nodeName) + .each(parentEdgeObj => { + + let inbound = parentEdgeObj.w === nodeName; + let parentMetaedge = parentGraph.edge(parentEdgeObj); + + // The parent's Metaedge represents some number of underlying + // BaseEdges from the original full graph. For each of those, we need + // to determine which immediate child is involved and make sure + // there's a Metaedge in the bridgegraph that covers it. + _.each(parentMetaedge.baseEdgeList, baseEdge => { + + // Based on the direction, figure out which is the descendant node + // and which is the 'other' node (sibling of parent or ancestor). + let [descendantName, otherName] = + inbound ? + [baseEdge.w, parentEdgeObj.v] : + [baseEdge.v, parentEdgeObj.w]; + + // Determine the immediate child containing this descendant node. + let childName = this.getChildName(nodeName, descendantName); + + // Look for an existing Metaedge in the bridgegraph (or create a + // new one) that covers the relationship between child and other. + let bridgeEdgeObj = { + v: inbound ? otherName : childName, + w: inbound ? childName : otherName, + }; + let bridgeMetaedge = bridgegraph.edge(bridgeEdgeObj); + if (!bridgeMetaedge) { + bridgeMetaedge = createMetaedge(bridgeEdgeObj.v, bridgeEdgeObj.w); + bridgeMetaedge.inbound = inbound; + bridgegraph.setEdge(bridgeEdgeObj.v, bridgeEdgeObj.w, + bridgeMetaedge); + } + + // Copy the BaseEdge from the parent's Metaedge into this + // bridgegraph Metaedge. + bridgeMetaedge.addBaseEdge(baseEdge, this); + }); + }) + .value(); // force lodash chain execution. + }); + + return bridgegraph; + } + + /** + * Utility function for determining the name of the immediate child under a + * node for a given descendant path. If the descendant corresponds to no + * immediate child, an error is thrown. + */ + getChildName(nodeName: string, descendantName: string): string { + // Walk up the hierarchy from the descendant to find the child. + let currentNode: Node = this.index[descendantName]; + while (currentNode) { + if (currentNode.parentNode && currentNode.parentNode.name === nodeName) { + return currentNode.name; + } + currentNode = currentNode.parentNode; + } + throw Error( + 'Could not find immediate child for descendant: ' + descendantName); + }; + + /** Given the name of a node, return its incoming metaedges. */ + getPredecessors(nodeName: string): Edges { + let node = this.index[nodeName]; + if (!node) { + throw Error('Could not find node with name: ' + nodeName); + } + + let predecessors = this.getOneWayEdges(node, true); + // Add embedded predecessors, such as constants. + if (!node.isGroupNode) { + _.each((node).inEmbeddings, embeddedNode => { + _.each((node).inputs, input => { + if (input.name === embeddedNode.name) { + // Make a new metaedge holding the edge between the + // node and the in-embedding. + let metaedge = new MetaedgeImpl(embeddedNode.name, nodeName); + metaedge.addBaseEdge( + { + isControlDependency: input.isControlDependency, + outputTensorIndex: input.outputTensorIndex, + isReferenceEdge: false, + v: embeddedNode.name, + w: nodeName + }, + this); + predecessors.regular.push(metaedge); + } + }); + }); + } + return predecessors; + } + + /** + * Given the name of a node, return its outgoing metaedges. + * + * This is the inverse of getPredecessors(). See that method's documentation + * for an in-depth example. + */ + getSuccessors(nodeName: string): Edges { + let node = this.index[nodeName]; + if (!node) { + throw Error('Could not find node with name: ' + nodeName); + } + + let successors = this.getOneWayEdges(node, false); + + // Add embedded successors, such as summaries. + if (!node.isGroupNode) { + _.each((node).outEmbeddings, embeddedNode => { + _.each(embeddedNode.inputs, input => { + if (input.name === nodeName) { + // Make a new metaedge holding the edge between the + // node and the out-embedding. + let metaedge = new MetaedgeImpl(nodeName, embeddedNode.name); + metaedge.addBaseEdge( + { + isControlDependency: input.isControlDependency, + outputTensorIndex: input.outputTensorIndex, + isReferenceEdge: false, + v: nodeName, + w: embeddedNode.name + }, + this); + successors.regular.push(metaedge); + } + }); + }); + } + return successors; + } + + /** Helper method for getPredecessors and getSuccessors */ + getOneWayEdges(node: GroupNode|OpNode, inEdges: boolean) { + let edges: Edges = {control: [], regular: []}; + // A node with no parent cannot have any edges. + if (!node.parentNode || !node.parentNode.isGroupNode) { + return edges; + } + let parentNode = node.parentNode; + let metagraph = parentNode.metagraph; + let bridgegraph = this.getBridgegraph(parentNode.name); + findEdgeTargetsInGraph(metagraph, node, inEdges, edges); + findEdgeTargetsInGraph(bridgegraph, node, inEdges, edges); + return edges; + } + + /** + * For a given GroupNode, get or calculate an object which describes a + * topological ordering of child nodes within that GroupNode's metagraph. + * + * This ordering is used when rendering bridge control edges which are + * sometimes backwards relative to the dataflow. + * + * For example, say we have a graph with two edges A->B and A->C, and we're + * interested in the ordering under ROOT. In this case, any of the following + * would be legitimate return values: + * + * - { 'A': 0, 'B': 1, 'C': 2 } -- most likely + * - { 'A': 0, 'B': 2, 'C': 1 } -- less likely + * - { 'A': 12, 'B': 100, 'C': 99 } -- unlikely, but still OK + * + * The algorithm does not guarantee that all numbers from 0-N (where N is + * the number of nodes) appear exactly once. Rather it guarantees that if + * there is a path between two nodes, the earlier one will have a lower + * number in the ordering hash. + * + * When generating the ordering, we ignore control Metaedges (those which + * represent only BaseEdges that have isControlDependency set to true). + * + * If there is no node with the specified name, an error is thrown. If the + * node with the specified name is not a group node, null is returned. + */ + getTopologicalOrdering(nodeName: string): { [childName: string]: number } { + let node = this.index[nodeName]; + if (!node) { + throw Error('Could not find node with name: ' + nodeName); + } + if (!node.isGroupNode) { + return null; + } + if (nodeName in this.orderings) { + return this.orderings[nodeName]; + } + + // Mapping of a child node names to lists of their successors. + let successors: { [childName: string]: string[] } = {}; + + // Set of node names which have appeared as a destination. + let destinations: { [childName: string]: boolean } = {}; + + let metagraph = ( node).metagraph; + _.each(metagraph.edges(), (e: graphlib.EdgeObject) => { + if (!metagraph.edge(e).numRegularEdges) { + return; // Skip control edges. + } + + // Keep track of successors and destinations. + if (!(e.v in successors)) { + successors[e.v] = []; + } + successors[e.v].push(e.w); + destinations[e.w] = true; + }); + + // Seed the queue with true sources (those that are not destinations). + let queue: string[] = + _.difference(_.keys(successors), _.keys(destinations)); + + // Produce an ordering by traversing the graph breadth first. + let ordering = this.orderings[nodeName] = {}; + let index = 0; + while (queue.length) { + let childName = queue.shift(); + ordering[childName] = index++; + _.each(successors[childName], succName => queue.push(succName)); + delete successors[childName]; // Prevent cycles from infinite looping. + } + return ordering; + } + + /** + * Returns a d3 Ordinal function that can be used to look up the index of + * a node based on its template id. + */ + getTemplateIndex(): (string) => number { + let templateNames = d3.keys(this.templates); + let templateIndex = d3.scaleOrdinal() + .domain(templateNames) + .range(d3.range(0, templateNames.length)); + return (templateId: string) => templateIndex(templateId); + } +} + +/** + * Internal utility function - given a graph (should be either a metagraph or a + * bridgegraph) and a node which is known to be in that graph, determine + * the other ends of edges that involve that node in the direction specified + * by whether it's inbound. + * + * For example if you wanted to find the predecessors of a node, you'd call + * this method for the parent's metagraph and bridgegraph, specifying inbound + * as true (look at the source of inbound edges to the specified node). + * + * Discovered target names are appended to the targets array. + */ +function findEdgeTargetsInGraph( + graph: graphlib.Graph, + node: Node, inbound: boolean, targets: Edges): void { + let edges = inbound ? graph.inEdges(node.name) : graph.outEdges(node.name); + _.each(edges, e => { + let metaedge = graph.edge(e); + let targetList = + metaedge.numRegularEdges ? targets.regular : targets.control; + targetList.push(metaedge); + }); +} + +export interface HierarchyParams { + verifyTemplate: boolean; + seriesNodeMinSize: number; + seriesMap: { [name: string]: tf.graph.SeriesGroupingType }; +} + +/** + * @param graph The raw graph. + * @param params Parameters used when building a hierarchy. + */ +export function build(graph: tf.graph.SlimGraph, params: HierarchyParams, + tracker: ProgressTracker): Promise { + let h = new HierarchyImpl(); + let seriesNames: { [name: string]: string } = {}; + return tf.graph.util + .runAsyncTask( + 'Adding nodes', 20, + () => { + // Get all the possible device and XLA cluster names. + let deviceNames = {}; + let xlaClusterNames = {}; + _.each(graph.nodes, (node, nodeName) => { + if (node.device) { + deviceNames[node.device] = true; + } + + if (node.xlaCluster) { + xlaClusterNames[node.xlaCluster] = true; + } + }); + + h.devices = _.keys(deviceNames); + h.xlaClusters = _.keys(xlaClusterNames); + + addNodes(h, graph); + }, + tracker) + .then(() => { + return tf.graph.util.runAsyncTask('Detect series', 20, () => { + if (params.seriesNodeMinSize > 0) { + groupSeries( + h.root, h, seriesNames, params.seriesNodeMinSize, + params.seriesMap); + } + }, tracker); + }) + .then(() => { + return tf.graph.util.runAsyncTask('Adding edges', 30, () => { + addEdges(h, graph, seriesNames); + }, tracker); + }) + .then(() => { + return tf.graph.util.runAsyncTask( + 'Finding similar subgraphs', 30, () => { + h.templates = template.detect(h, params.verifyTemplate); + }, tracker); + }) + .then(() => { + return h; + }); +}; + +export function joinAndAggregateStats( + h: Hierarchy, stats: tf.graph.proto.StepStats) { + // Get all the possible device names. + let deviceNames = {}; + _.each(h.root.leaves(), nodeName => { + let leaf = h.node(nodeName); + if (leaf.device != null) { + deviceNames[leaf.device] = true; + } + }); + h.devices = _.keys(deviceNames); + + // Reset stats for each group node. + _.each(h.getNodeMap(), (node, nodeName) => { + if (node.isGroupNode) { + node.stats = new NodeStats(null); + (node).deviceHistogram = {}; + } + }); + + // Bubble-up the stats and device distribution from leaves to parents. + _.each(h.root.leaves(), nodeName => { + let leaf = h.node(nodeName); + let node = leaf; + while (node.parentNode != null) { + if (leaf.device != null) { + let deviceHistogram = (node.parentNode).deviceHistogram; + deviceHistogram[leaf.device] = (deviceHistogram[leaf.device] || 0) + 1; + } + if (leaf.stats != null) { + node.parentNode.stats.combine(leaf.stats); + } + node = node.parentNode; + } + }); +} + +/** + * Creates the metanodes in the hierarchical graph and assigns parent-child + * relationship between them. + */ +function addNodes(h: Hierarchy, graph: SlimGraph) { + _.each(graph.nodes, (node, nodeName) => { + let path = getHierarchicalPath(node.name); + let parent: Metanode = h.root; + + parent.depth = Math.max(path.length, parent.depth); + + // Create parent metanodes for each depth. For example if the node name + // is 'a/b/c', then create metanodes 'a' and 'a/b', where 'a/b' is a child + // of a. + for (let i = 0; i < path.length; i++) { + parent.depth = Math.max(parent.depth, path.length - i); + parent.cardinality += node.cardinality; + parent.opHistogram[node.op] = (parent.opHistogram[node.op] || 0) + 1; + if (node.device != null) { + parent.deviceHistogram[node.device] = + (parent.deviceHistogram[node.device] || 0) + 1; + } + if (i === path.length - 1) { break; } + let name = path[i]; + let child = h.node(name); + if (!child) { + child = createMetanode(name); + child.parentNode = parent; + h.setNode(name, child); + parent.metagraph.setNode(name, child); + } + parent = child; + } + // Assuming node name is 'a/b/c', assign the OpNode as a child of the + // metanode 'a/b'. + h.setNode(node.name, node); + node.parentNode = parent; + parent.metagraph.setNode(node.name, node); + + // Add each of the in-embeddings and out-embeddings in the hierarchy. + _.each(node.inEmbeddings, function(embedding) { + h.setNode(embedding.name, embedding); + embedding.parentNode = node; + }); + _.each(node.outEmbeddings, function(embedding) { + h.setNode(embedding.name, embedding); + embedding.parentNode = node; + }); + }); +}; + +/** + * For each metanode in the hierarchical graph, this method adds: + * the edges in the metagraph. These are edges between nodes + * that share the same parent. + */ +function addEdges(h: Hierarchy, graph: SlimGraph, + seriesNames: { [name: string]: string }) { + + let nodeIndex = h.getNodeMap(); + + // Ancestor paths for the source and destination nodes of an edge. These are + // reused for each edge rather than allocating new ones. It's about 10% faster + // than allocating new ones on each pass through the loop. + let sourcePath: string[] = []; + let destPath: string[] = []; + + // Insert the ancestor path for a node into the provided array, including the + // node itself. Return the index of the last node inserted (always ROOT). + let getPath = (node: Node, path: string[]): number => { + let i = 0; + while (node) { + path[i++] = node.name; + node = node.parentNode; + } + return i - 1; + }; + + _.each(graph.edges, baseEdge => { + + // Get the hierarchical paths for the source and destination of the edge. + let sourceAncestorIndex = getPath(graph.nodes[baseEdge.v], sourcePath); + let destAncestorIndex = getPath(graph.nodes[baseEdge.w], destPath); + + // If the hierarchical path cannot be found for either endpoint, then we + // cannot create the edge. This happens for example when a node has a + // control dependency on a summary node, which are embedded. + if (sourceAncestorIndex === -1 || destAncestorIndex === -1) { + return; + } + + // Find the lowest shared ancestor between source and dest by looking for + // the highest nodes that differ between their ancestor paths. + while (sourcePath[sourceAncestorIndex] === destPath[destAncestorIndex]) { + sourceAncestorIndex--; + destAncestorIndex--; + if (sourceAncestorIndex < 0 || destAncestorIndex < 0) { + // This would only occur if the two nodes were the same (a cycle in the + // graph), or if one endpoint was a strict ancestor of the other. The + // latter shouldn't happen because we rename nodes which are both + // metanodes and op nodes. E.g. 'A/B' becomes 'A/B/(B)'. + throw Error('No difference found between ancestor paths.'); + } + } + + let sharedAncestorNode = + nodeIndex[sourcePath[sourceAncestorIndex + 1]]; + let sourceAncestorName = sourcePath[sourceAncestorIndex]; + let destAncestorName = destPath[destAncestorIndex]; + + // Find or create the Metaedge which should contain this BaseEdge inside + // the shared ancestor. + let metaedge = + sharedAncestorNode.metagraph.edge(sourceAncestorName, destAncestorName); + if (!metaedge) { + metaedge = createMetaedge(sourceAncestorName, destAncestorName); + sharedAncestorNode.metagraph + .setEdge(sourceAncestorName, destAncestorName, metaedge); + } + if (!sharedAncestorNode.hasNonControlEdges && + !baseEdge.isControlDependency) { + sharedAncestorNode.hasNonControlEdges = true; + } + metaedge.addBaseEdge(baseEdge, h); + }); +}; + +/** + * Using the hierarchy template information, detect series in the provided + * metanode. For each detected series, create a new SeriesNode + * and remove series members from the metanode's metagraph and move them to + * the new series node's metagraph. + * + * @param metanode + * @param hierarchy + * @param seriesNames Map of node names to their series they are contained in. + * This should be provided empty and is populated by this method. + * @param threshold If the series has this many nodes or more, then group them + * into a series. + * @param map Map of series names to their series grouping type, if one has + * been set. + * @return A dictionary from node name to series node name that contains the + * node. + */ +function groupSeries(metanode: Metanode, hierarchy: Hierarchy, + seriesNames: { [name: string]: string }, threshold: number, + map: { [name: string]: tf.graph.SeriesGroupingType }) { + let metagraph = metanode.metagraph; + _.each(metagraph.nodes(), n => { + let child = metagraph.node(n); + if (child.type === tf.graph.NodeType.META) { + groupSeries(child, hierarchy, seriesNames, threshold, map); + } + }); + + let clusters = clusterNodes(metagraph); + let seriesDict = detectSeries(clusters, metagraph); + + // Add each series node to the graph and add its grouped children to its own + // metagraph. + _.each(seriesDict, function(seriesNode: SeriesNode, seriesName: string) { + let nodeMemberNames = seriesNode.metagraph.nodes(); + _.each(nodeMemberNames, n => { + let child = metagraph.node(n); + if (!child.owningSeries) { + child.owningSeries = seriesName; + } + }); + // If the series contains less than the threshold number of nodes and + // this series has not been adding to the series map, then set this + // series to be shown ungrouped in the map. + if (nodeMemberNames.length < threshold && !(seriesNode.name in map)) { + map[seriesNode.name] = tf.graph.SeriesGroupingType.UNGROUP; + } + // If the series is in the map as ungrouped then do not group the series. + if (seriesNode.name in map + && map[seriesNode.name] === tf.graph.SeriesGroupingType.UNGROUP) { + return; + } + hierarchy.setNode(seriesName, seriesNode); // add to the index + metagraph.setNode(seriesName, seriesNode); + _.each(nodeMemberNames, n => { + let child = metagraph.node(n); + seriesNode.metagraph.setNode(n, child); + seriesNode.parentNode = child.parentNode; + seriesNode.cardinality++; + if (child.device != null) { + seriesNode.deviceHistogram[child.device] = + (seriesNode.deviceHistogram[child.device] || 0) + 1; + } + child.parentNode = seriesNode; + seriesNames[n] = seriesName; + // Remove now-grouped node from its original parent's metagraph. + metagraph.removeNode(n); + }); + }); +}; + +/** cluster op-nodes with similar op */ +function clusterNodes(metagraph: graphlib.Graph): + {[clusterId: string]: string[]} { + let result: {[clusterId: string]: string[]} = {}; + return _.reduce(metagraph.nodes(), + (clusters: {[clusterId: string]: string[]}, n: string) => { + let child = metagraph.node(n); + if (child.type === NodeType.META) { + // skip metanodes + return clusters; + } + let template = (child).op; + if (template) { + clusters[template] = clusters[template] || []; + clusters[template].push(child.name); + } + return clusters; + }, result); +} + +/** + * For each cluster of op-nodes based op type, try to detect groupings. + * Infer series name using by trying to find pattern '' in the node + * name. + * + * @param clusters Dictionary output from clusterNodes(). + * @param metagraph + * @return A dictionary from series name => seriesNode + */ +function detectSeries(clusters: {[clusterId: string]: string[]}, + metagraph: graphlib.Graph): + {[seriesName: string]: SeriesNode} { + let seriesDict: {[seriesName: string]: SeriesNode} = {}; + _.each(clusters, function(members, clusterId: string) { + if (members.length <= 1) { return; } // isolated clusters can't make series + + /** @type {Object} A dictionary mapping seriesName to seriesInfoArray, + * which is an array that contains objects with name, id, prefix, suffix, + * and parent properties. + */ + let candidatesDict: {[seriesName: string]: SeriesNode[]} = {}; + + // Group all nodes that have the same name, with the exception of a + // number at the end of the name after an underscore, which is allowed to + // vary. + _.each(members, function(name: string) { + let isGroup = name.charAt(name.length - 1) === '*'; + let namepath = name.split('/'); + let leaf = namepath[namepath.length - 1]; + let parent = namepath.slice(0, namepath.length - 1).join('/'); + let matches = leaf.match(/^(\D*)_(\d+)$/); + + let prefix; + let id; + let suffix = ''; + if (matches) { // if found '' in the name, assign id. + prefix = matches[1]; // the front non-numeric characters + id = matches[2]; // the digits + } else { // for node without '_', make them zero-th items. + prefix = isGroup ? leaf.substr(0, leaf.length - 1) : leaf; + id = 0; + suffix = isGroup ? '*' : ''; + } + let seriesName = getSeriesNodeName(prefix, suffix, parent); + candidatesDict[seriesName] = candidatesDict[seriesName] || []; + let seriesNode = createSeriesNode(prefix, suffix, parent, +id, name); + candidatesDict[seriesName].push(seriesNode); + }); + + // In each group of nodes, group nodes in bunches that have monotonically + // increasing numbers in their names. Each of these bunches is a series. + _.each(candidatesDict, function(seriesInfoArray: SeriesNode[], seriesName) { + if (seriesInfoArray.length < 2) { + return; + } + seriesInfoArray.sort(function(a, b) { + return (+a.clusterId) - (+b.clusterId); + }); + + // Loop through the nodes sorted by its detected series number, grouping + // all nodes with monotonically-increasing series numbers. + let seriesNodes = [seriesInfoArray[0]]; + for (let index = 1; index < seriesInfoArray.length; index++) { + let nextNode = seriesInfoArray[index]; + if (nextNode.clusterId === seriesNodes[seriesNodes.length - 1].clusterId + + 1) { + seriesNodes.push(nextNode); + continue; + } + addSeriesToDict(seriesNodes, seriesDict, +clusterId, metagraph); + seriesNodes = [nextNode]; + } + addSeriesToDict(seriesNodes, seriesDict, +clusterId, metagraph); + }); + }); + return seriesDict; +} + +/** + * Add a series to the provided dictionary mapping series names to series. + * + * @param seriesNodes the nodes in the series. Contains + * name, id, prefix, suffix and parent properties of the node. + * @param seriesDict the dictionary of series + * @param clusterId ID of the template of the nodes of the series + * @param metagraph + */ +function addSeriesToDict(seriesNodes: SeriesNode[], + seriesDict: {[seriesName: string]: SeriesNode}, + clusterId: number, + metagraph: graphlib.Graph) { + if (seriesNodes.length > 1) { + let curSeriesName = getSeriesNodeName( + seriesNodes[0].prefix, seriesNodes[0].suffix, + seriesNodes[0].parent, seriesNodes[0].clusterId, + seriesNodes[seriesNodes.length - 1].clusterId); + let curSeriesNode = createSeriesNode(seriesNodes[0].prefix, + seriesNodes[0].suffix, seriesNodes[0].parent, clusterId, + curSeriesName); + _.each(seriesNodes, function(node) { + curSeriesNode.ids.push(node.clusterId); + curSeriesNode.metagraph.setNode(node.name, metagraph.node(node.name)); + }); + seriesDict[curSeriesName] = curSeriesNode; + } +} + +} // close module tf.graph.hierarchy diff --git a/tensorflow/tensorboard/components/tf_graph_common_d3v4/layout.ts b/tensorflow/tensorboard/components/tf_graph_common_d3v4/layout.ts new file mode 100644 index 0000000000..11f41cfdd0 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_graph_common_d3v4/layout.ts @@ -0,0 +1,758 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the 'License'); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an 'AS IS' BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +module tf.graph.layout { + +/** Set of parameters that define the look and feel of the graph. */ +export const PARAMS = { + animation: { + /** Default duration for graph animations in ms. */ + duration: 250 + }, + graph: { + /** Graph parameter for metanode. */ + meta: { + /** + * Dagre's nodesep param - number of pixels that + * separate nodes horizontally in the layout. + * + * See https://github.com/cpettitt/dagre/wiki#configuring-the-layout + */ + nodeSep: 5, + /** + * Dagre's ranksep param - number of pixels + * between each rank in the layout. + * + * See https://github.com/cpettitt/dagre/wiki#configuring-the-layout + */ + rankSep: 25, + /** + * Dagre's edgesep param - number of pixels that separate + * edges horizontally in the layout. + */ + edgeSep: 5, + }, + /** Graph parameter for metanode. */ + series: { + /** + * Dagre's nodesep param - number of pixels that + * separate nodes horizontally in the layout. + * + * See https://github.com/cpettitt/dagre/wiki#configuring-the-layout + */ + nodeSep: 5, + /** + * Dagre's ranksep param - number of pixels + * between each rank in the layout. + * + * See https://github.com/cpettitt/dagre/wiki#configuring-the-layout + */ + rankSep: 25, + /** + * Dagre's edgesep param - number of pixels that separate + * edges horizontally in the layout. + */ + edgeSep: 5 + }, + /** + * Padding is used to correctly position the graph SVG inside of its parent + * element. The padding amounts are applied using an SVG transform of X and + * Y coordinates. + */ + padding: {paddingTop: 40, paddingLeft: 20} + }, + subscene: { + meta: { + paddingTop: 10, + paddingBottom: 10, + paddingLeft: 10, + paddingRight: 10, + /** + * Used to leave room for the label on top of the highest node in + * the core graph. + */ + labelHeight: 20, + /** X-space between each extracted node and the core graph. */ + extractXOffset: 15, + /** Y-space between each extracted node. */ + extractYOffset: 20 + }, + series: { + paddingTop: 10, + paddingBottom: 10, + paddingLeft: 10, + paddingRight: 10, + labelHeight: 10 + } + }, + nodeSize: { + /** Size of meta nodes. */ + meta: { + radius: 5, + width: 60, + maxLabelWidth: 52, + /** A scale for the node's height based on number of nodes inside */ + height: d3.scaleLinear().domain([1, 200]).range([15, 60]).clamp(true), + /** The radius of the circle denoting the expand button. */ + expandButtonRadius: 3 + }, + /** Size of op nodes. */ + op: { + width: 15, + height: 6, + radius: 3, // for making annotation touching ellipse + labelOffset: -8, + maxLabelWidth: 30 + }, + /** Size of series nodes. */ + series: { + expanded: { + // For expanded series nodes, width and height will be + // computed to account for the subscene. + radius: 10, + labelOffset: 0, + }, + vertical: { + // When unexpanded, series whose underlying metagraphs contain + // one or more non-control edges will show as a vertical stack + // of ellipses. + width: 16, + height: 13, + labelOffset: -13, + }, + horizontal: { + // When unexpanded, series whose underlying metagraphs contain + // no non-control edges will show as a horizontal stack of + // ellipses. + width: 24, + height: 8, + radius: 10, // Forces annotations to center line. + labelOffset: -10, + }, + }, + /** Size of bridge nodes. */ + bridge: { + // NOTE: bridge nodes will normally be invisible, but they must + // take up some space so that the layout step leaves room for + // their edges. + width: 20, + height: 20, + radius: 2, + labelOffset: 0 + } + }, + shortcutSize: { + /** Size of shortcuts for op nodes */ + op: {width: 10, height: 4}, + /** Size of shortcuts for meta nodes */ + meta: {width: 12, height: 4, radius: 1}, + /** Size of shortcuts for series nodes */ + series: { + width: 14, + height: 4, + } + }, + annotations: { + /** Maximum possible width of the bounding box for in annotations */ + inboxWidth: 50, + /** Maximum possible width of the bounding box for out annotations */ + outboxWidth: 50, + /** X-space between the shape and each annotation-node. */ + xOffset: 10, + /** Y-space between each annotation-node. */ + yOffset: 3, + /** X-space between each annotation-node and its label. */ + labelOffset: 2, + /** Defines the max width for annotation label */ + maxLabelWidth: 120 + }, + constant: {size: {width: 4, height: 4}}, + series: { + /** Maximum number of repeated item for unexpanded series node. */ + maxStackCount: 3, + /** + * Positioning offset ratio for collapsed stack + * of parallel series (series without edges between its members). + */ + parallelStackOffsetRatio: 0.2, + /** + * Positioning offset ratio for collapsed stack + * of tower series (series with edges between its members). + */ + towerStackOffsetRatio: 0.5 + }, + minimap: { + /** The maximum width/height the minimap can have. */ + size: 150 + } +}; + +/** Calculate layout for a scene of a group node. */ +export function layoutScene(renderNodeInfo: render.RenderGroupNodeInfo): void { + // Update layout, size, and annotations of its children nodes and edges. + if (renderNodeInfo.node.isGroupNode) { + layoutChildren(renderNodeInfo); + } + + // Update position of its children nodes and edges + if (renderNodeInfo.node.type === NodeType.META) { + layoutMetanode(renderNodeInfo); + } else if (renderNodeInfo.node.type === NodeType.SERIES) { + layoutSeriesNode(renderNodeInfo); + } +}; + +/** + * Updates the total width of an unexpanded node which includes the size of its + * in and out annotations. + */ +function updateTotalWidthOfNode(renderInfo: render.RenderNodeInfo): void { + renderInfo.inboxWidth = renderInfo.inAnnotations.list.length > 0 ? + PARAMS.annotations.inboxWidth : 0; + renderInfo.outboxWidth = renderInfo.outAnnotations.list.length > 0 ? + PARAMS.annotations.outboxWidth : 0; + // Assign the width of the core box (the main shape of the node). + renderInfo.coreBox.width = renderInfo.width; + renderInfo.coreBox.height = renderInfo.height; + // TODO(jimbo): Account for font width rather than using a magic number. + let labelLength = renderInfo.node.name.length - + renderInfo.node.name.lastIndexOf(NAMESPACE_DELIM) - 1; + let charWidth = 3; // 3 pixels per character. + // Compute the total width of the node. + renderInfo.width = Math.max(renderInfo.coreBox.width + + renderInfo.inboxWidth + renderInfo.outboxWidth, + labelLength * charWidth); + +} + +/** + * Update layout, size, and annotations of its children nodes and edges. + */ +function layoutChildren(renderNodeInfo: render.RenderGroupNodeInfo): void { + let children = renderNodeInfo.coreGraph.nodes().map(n => { + return renderNodeInfo.coreGraph.node(n); + }).concat(renderNodeInfo.isolatedInExtract, + renderNodeInfo.isolatedOutExtract); + + _.each(children, childNodeInfo => { + // Set size of each child + switch (childNodeInfo.node.type) { + case NodeType.OP: + _.extend(childNodeInfo, PARAMS.nodeSize.op); + break; + case NodeType.BRIDGE: + _.extend(childNodeInfo, PARAMS.nodeSize.bridge); + break; + case NodeType.META: + if (!childNodeInfo.expanded) { + // Set fixed width and scalable height based on cardinality + _.extend(childNodeInfo, PARAMS.nodeSize.meta); + childNodeInfo.height = + PARAMS.nodeSize.meta.height(childNodeInfo.node.cardinality); + } else { + let childGroupNodeInfo = + childNodeInfo; + layoutScene(childGroupNodeInfo); // Recursively layout its subscene. + } + break; + case NodeType.SERIES: + if (childNodeInfo.expanded) { + _.extend(childNodeInfo, PARAMS.nodeSize.series.expanded); + let childGroupNodeInfo = + childNodeInfo; + layoutScene(childGroupNodeInfo); // Recursively layout its subscene. + } else { + let childGroupNodeInfo = + childNodeInfo; + let seriesParams = + childGroupNodeInfo.node.hasNonControlEdges ? + PARAMS.nodeSize.series.vertical : + PARAMS.nodeSize.series.horizontal; + _.extend(childNodeInfo, seriesParams); + } + break; + default: + throw Error('Unrecognized node type: ' + childNodeInfo.node.type); + } + // Compute total width of un-expanded nodes. Width of expanded nodes + // has already been computed. + if (!childNodeInfo.expanded) { + updateTotalWidthOfNode(childNodeInfo); + } + // Layout each child's annotations + layoutAnnotation(childNodeInfo); + }); +} + +/** + * Calculate layout for a graph using dagre + * @param graph the graph to be laid out + * @param params layout parameters + * @return width and height of the core graph + */ +function dagreLayout( + graph: graphlib.Graph, + params): {height: number, width: number} { + _.extend(graph.graph(), { + nodesep: params.nodeSep, + ranksep: params.rankSep, + edgesep: params.edgeSep + }); + let bridgeNodeNames = []; + let nonBridgeNodeNames = []; + + // Split out nodes into bridge and non-bridge nodes, and calculate the total + // width we should use for bridge nodes. + _.each(graph.nodes(), nodeName => { + let nodeInfo = graph.node(nodeName); + if (nodeInfo.node.type === NodeType.BRIDGE) { + bridgeNodeNames.push(nodeName); + } else { + nonBridgeNodeNames.push(nodeName); + } + }); + + // If there are no non-bridge nodes, then the graph has zero size. + if (!nonBridgeNodeNames.length) { + return { + width: 0, + height: 0, + }; + } + dagre.layout(graph); + + // Calculate the true bounding box of the graph by iterating over nodes and + // edges rather than accepting dagre's word for it. In particular, we should + // ignore the extra-wide bridge nodes and bridge edges, and allow for + // annotation boxes and labels. + let minX = Infinity; + let minY = Infinity; + let maxX = -Infinity; + let maxY = -Infinity; + _.each(nonBridgeNodeNames, nodeName => { + let nodeInfo = graph.node(nodeName); + let w = 0.5 * nodeInfo.width; + let x1 = nodeInfo.x - w; + let x2 = nodeInfo.x + w; + minX = x1 < minX ? x1 : minX; + maxX = x2 > maxX ? x2 : maxX; + // TODO(jimbo): Account for the height of labels above op nodes here. + let h = 0.5 * nodeInfo.height; + let y1 = nodeInfo.y - h; + let y2 = nodeInfo.y + h; + minY = y1 < minY ? y1 : minY; + maxY = y2 > maxY ? y2 : maxY; + }); + _.each(graph.edges(), edgeObj => { + let edgeInfo = graph.edge(edgeObj); + if (edgeInfo.structural) { + return; // Skip structural edges from min/max calculations. + } + + // Since the node size passed to dagre includes the in and out + // annotations, the endpoints of the edge produced by dagre may not + // point to the actual node shape (rectangle, ellipse). We correct the + // end-points by finding the intersection of a line between the + // next-to-last (next-to-first) point and the destination (source) + // rectangle. + let sourceNode = graph.node(edgeInfo.metaedge.v); + let destNode = graph.node(edgeInfo.metaedge.w); + + // Straight 3-points edges are special case, since they are curved after + // our default correction. To keep them straight, we remove the mid point + // and correct the first and the last point to be the center of the + // source and destination node respectively. + if (edgeInfo.points.length === 3 && isStraightLine(edgeInfo.points)) { + if (sourceNode != null) { + let cxSource = sourceNode.expanded ? + sourceNode.x : computeCXPositionOfNodeShape(sourceNode); + edgeInfo.points[0].x = cxSource; + } + if (destNode != null) { + let cxDest = destNode.expanded ? + destNode.x : computeCXPositionOfNodeShape(destNode); + edgeInfo.points[2].x = cxDest; + } + // Remove the middle point so the edge doesn't curve. + edgeInfo.points = [edgeInfo.points[0], edgeInfo.points[1]]; + } + // Correct the destination endpoint of the edge. + let nextToLastPoint = edgeInfo.points[edgeInfo.points.length - 2]; + // The destination node might be null if this is a bridge edge. + if (destNode != null) { + edgeInfo.points[edgeInfo.points.length - 1] = + intersectPointAndNode(nextToLastPoint, destNode); + } + // Correct the source endpoint of the edge. + let secondPoint = edgeInfo.points[1]; + // The source might be null if this is a bridge edge. + if (sourceNode != null) { + edgeInfo.points[0] = intersectPointAndNode(secondPoint, sourceNode); + } + + _.each(edgeInfo.points, (point: render.Point) => { + minX = point.x < minX ? point.x : minX; + maxX = point.x > maxX ? point.x : maxX; + minY = point.y < minY ? point.y : minY; + maxY = point.y > maxY ? point.y : maxY; + }); + }); + + // Shift all nodes and edge points to account for the left-padding amount, + // and the invisible bridge nodes. + _.each(graph.nodes(), nodeName => { + let nodeInfo = graph.node(nodeName); + nodeInfo.x -= minX; + nodeInfo.y -= minY; + }); + _.each(graph.edges(), edgeObj => { + _.each(graph.edge(edgeObj).points, (point: render.Point) => { + point.x -= minX; + point.y -= minY; + }); + }); + + return { + width: maxX - minX, + height: maxY - minY + }; +} + +/** Layout a metanode. Only called for an expanded node. */ +function layoutMetanode(renderNodeInfo: render.RenderGroupNodeInfo): void { + // First, copy params specific to meta nodes onto this render info object. + let params = PARAMS.subscene.meta; + _.extend(renderNodeInfo, params); + // Invoke dagre.layout() on the core graph and record the bounding box + // dimensions. + _.extend(renderNodeInfo.coreBox, + dagreLayout(renderNodeInfo.coreGraph, PARAMS.graph.meta)); + + // Calculate the position of nodes in isolatedInExtract relative to the + // top-left corner of inExtractBox (the bounding box for all inExtract nodes) + // and calculate the size of the inExtractBox. + let maxInExtractWidth = _.max(renderNodeInfo.isolatedInExtract, + renderNode => renderNode.width).width; + renderNodeInfo.inExtractBox.width = maxInExtractWidth != null ? + maxInExtractWidth : 0; + + renderNodeInfo.inExtractBox.height = + _.reduce(renderNodeInfo.isolatedInExtract, (height, child, i) => { + let yOffset = i > 0 ? params.extractYOffset : 0; + // use width/height here to avoid overlaps between extracts + child.x = 0; + child.y = height + yOffset + child.height / 2; + return height + yOffset + child.height; + }, 0); + + // Calculate the position of nodes in isolatedOutExtract relative to the + // top-left corner of outExtractBox (the bounding box for all outExtract + // nodes) and calculate the size of the outExtractBox. + let maxOutExtractWidth = _.max(renderNodeInfo.isolatedOutExtract, + renderNode => renderNode.width).width; + renderNodeInfo.outExtractBox.width = maxOutExtractWidth != null ? + maxOutExtractWidth : 0; + + renderNodeInfo.outExtractBox.height = + _.reduce(renderNodeInfo.isolatedOutExtract, (height, child, i) => { + let yOffset = i > 0 ? params.extractYOffset : 0; + // use width/height here to avoid overlaps between extracts + child.x = 0; + child.y = height + yOffset + child.height / 2; + return height + yOffset + child.height; + }, 0); + + // Compute the total padding between the core graph, in-extract and + // out-extract boxes. + let numParts = 0; + if (renderNodeInfo.isolatedInExtract.length > 0) { + numParts++; + } + if (renderNodeInfo.isolatedOutExtract.length > 0) { + numParts++; + } + if (renderNodeInfo.coreGraph.nodeCount() > 0) { + numParts++; + } + let offset = PARAMS.subscene.meta.extractXOffset; + let padding = numParts <= 1 ? 0 : (numParts <= 2 ? offset : 2 * offset); + + // Add the in-extract and out-extract width to the core box width. + renderNodeInfo.coreBox.width += renderNodeInfo.inExtractBox.width + + renderNodeInfo.outExtractBox.width + padding; + renderNodeInfo.coreBox.height = + params.labelHeight + + Math.max( + renderNodeInfo.inExtractBox.height, + renderNodeInfo.coreBox.height, + renderNodeInfo.outExtractBox.height + ); + // Determine the whole metanode's width (from left to right). + renderNodeInfo.width = renderNodeInfo.coreBox.width + + params.paddingLeft + params.paddingRight; + + // Determine the whole metanode's height (from top to bottom). + renderNodeInfo.height = + renderNodeInfo.paddingTop + + renderNodeInfo.coreBox.height + + renderNodeInfo.paddingBottom; +} + +/** + * Calculate layout for series node's core graph. Only called for an expanded + * series. + */ +function layoutSeriesNode(node: render.RenderGroupNodeInfo): void { + let graph = node.coreGraph; + + let params = PARAMS.subscene.series; + _.extend(node, params); + + // Layout the core. + _.extend(node.coreBox, dagreLayout(node.coreGraph, PARAMS.graph.series)); + + _.each(graph.nodes(), nodeName => { + graph.node(nodeName).excluded = false; + }); + + // Series do not have in/outExtractBox so no need to include them here. + node.width = node.coreBox.width + params.paddingLeft + params.paddingRight; + node.height = node.coreBox.height + params.paddingTop + params.paddingBottom; +} + +/** + * Calculate layout for annotations of a given node. + * This will modify positions of the given node and its annotations. + * + * @see tf.graph.render.Node and tf.graph.render.Annotation + * for description of each property of each render node. + * + */ +function layoutAnnotation(renderNodeInfo: render.RenderNodeInfo): void { + // If the render node is an expanded metanode, then its annotations will not + // be visible and we should skip the annotation calculations. + if (renderNodeInfo.expanded) { + return; + } + + let inAnnotations = renderNodeInfo.inAnnotations.list; + let outAnnotations = renderNodeInfo.outAnnotations.list; + + // Calculate size for in-annotations + _.each(inAnnotations, a => sizeAnnotation(a)); + + // Calculate size for out-annotations + _.each(outAnnotations, a => sizeAnnotation(a)); + + let params = PARAMS.annotations; + + // Calculate annotation node position (a.dx, a.dy) + // and total height for in-annotations + // After this chunk of code: + // inboxHeight = sum of annotation heights+ (annotation.length - 1 * yOffset) + let inboxHeight = _.reduce(inAnnotations, + (height, a, i) => { + let yOffset = i > 0 ? params.yOffset : 0; + a.dx = -(renderNodeInfo.coreBox.width + a.width) / 2 - params.xOffset; + a.dy = height + yOffset + a.height / 2; + return height + yOffset + a.height; + }, 0); + + _.each(inAnnotations, a => { + a.dy -= inboxHeight / 2; + + a.labelOffset = params.labelOffset; + }); + + // Calculate annotation node position (a.dx, a.dy) + // and total height for out-annotations + // After this chunk of code: + // outboxHeight = sum of annotation heights + + // (annotation.length - 1 * yOffset) + let outboxHeight = _.reduce(outAnnotations, + (height, a, i) => { + let yOffset = i > 0 ? params.yOffset : 0; + a.dx = (renderNodeInfo.coreBox.width + a.width) / 2 + params.xOffset; + a.dy = height + yOffset + a.height / 2; + return height + yOffset + a.height; + }, 0); + + _.each(outAnnotations, a => { + // adjust by (half of ) the total height + // so dy is relative to the host node's center. + a.dy -= outboxHeight / 2; + + a.labelOffset = params.labelOffset; + }); + + // Creating scales for touch point between the in-annotation edges + // and their hosts. + + let inTouchHeight = + Math.min(renderNodeInfo.height / 2 - renderNodeInfo.radius, + inboxHeight / 2); + inTouchHeight = inTouchHeight < 0 ? 0 : inTouchHeight; + + let inY = d3.scaleLinear() + .domain([0, inAnnotations.length - 1]) + .range([-inTouchHeight, inTouchHeight]); + + // Calculate annotation edge position + _.each(inAnnotations, (a, i) => { + a.points = [ + // The annotation node end + { + dx: a.dx + a.width / 2, + dy: a.dy + }, + + // The host node end + { + dx: - renderNodeInfo.coreBox.width / 2, + // only use scale if there are more than one, + // otherwise center it vertically + dy: inAnnotations.length > 1 ? inY(i) : 0 + } + ]; + }); + + // Creating scales for touch point between the out-annotation edges + // and their hosts. + let outTouchHeight = + Math.min(renderNodeInfo.height / 2 - renderNodeInfo.radius, + outboxHeight / 2); + outTouchHeight = outTouchHeight < 0 ? 0 : outTouchHeight; + let outY = d3.scaleLinear() + .domain([0, outAnnotations.length - 1]) + .range([-outTouchHeight, outTouchHeight]); + + _.each(outAnnotations, (a, i) => { + // Add point from the border of the annotation node + a.points = [ + // The host node end + { + dx: renderNodeInfo.coreBox.width / 2, + // only use scale if there are more than one, + // otherwise center it vertically + dy: outAnnotations.length > 1 ? outY(i) : 0 + }, + // The annotation node end + { + dx: a.dx - a.width / 2, + dy: a.dy + } + ]; + }); + + renderNodeInfo.height = + Math.max(renderNodeInfo.height, inboxHeight, outboxHeight); +} + +/** + * Set size of an annotation node. + */ +function sizeAnnotation(a: render.Annotation): void { + switch (a.annotationType) { + case render.AnnotationType.CONSTANT: + _.extend(a, PARAMS.constant.size); + break; + case render.AnnotationType.SHORTCUT: + if (a.node.type === NodeType.OP) { + _.extend(a, PARAMS.shortcutSize.op); + } else if (a.node.type === NodeType.META) { + _.extend(a, PARAMS.shortcutSize.meta); + } else if (a.node.type === NodeType.SERIES) { + _.extend(a, PARAMS.shortcutSize.series); + } else { + throw Error('Invalid node type: ' + a.node.type); + } + break; + case render.AnnotationType.SUMMARY: + _.extend(a, PARAMS.constant.size); + break; + } +} + +/** + * Determines the center position of the node's shape. The position depends + * on if the node has in and out-annotations. + */ +export function computeCXPositionOfNodeShape(renderInfo: render.RenderNodeInfo): + number { + if (renderInfo.expanded) { + return renderInfo.x; + } + let dx = renderInfo.inAnnotations.list.length ? renderInfo.inboxWidth : 0; + return renderInfo.x - renderInfo.width / 2 + dx + + renderInfo.coreBox.width / 2; +} + +/** Returns the angle (in degrees) between two points. */ +function angleBetweenTwoPoints(a: render.Point, b: render.Point): number { + let dx = b.x - a.x; + let dy = b.y - a.y; + return 180 * Math.atan(dy / dx) / Math.PI; +} + +/** + * Returns if a line going through the specified points is a straight line. + */ +function isStraightLine(points: render.Point[]) { + let angle = angleBetweenTwoPoints(points[0], points[1]); + for (let i = 1; i < points.length - 1; i++) { + let newAngle = angleBetweenTwoPoints(points[i], points[i + 1]); + // Have a tolerance of 1 degree. + if (Math.abs(newAngle - angle) > 1) { + return false; + } + angle = newAngle; + } + return true; +} + +/** + * Returns the intersection of a line between the provided point + * and the provided rectangle. + */ +function intersectPointAndNode( + point: render.Point, node: render.RenderNodeInfo): render.Point { + // cx and cy are the center of the rectangle. + let cx = node.expanded ? + node.x : computeCXPositionOfNodeShape(node); + let cy = node.y; + // Calculate the slope + let dx = point.x - cx; + let dy = point.y - cy; + let w = node.expanded ? node.width : node.coreBox.width; + let h = node.expanded ? node.height : node.coreBox.height; + let deltaX, deltaY; + if (Math.abs(dy) * w / 2 > Math.abs(dx) * h / 2) { + // The intersection is above or below the rectangle. + if (dy < 0) { + h = -h; + } + deltaX = dy === 0 ? 0 : h / 2 * dx / dy; + deltaY = h / 2; + } else { + // The intersection is left or right of the rectangle. + if (dx < 0) { + w = -w; + } + deltaX = w / 2; + deltaY = dx === 0 ? 0 : w / 2 * dy / dx; + } + return {x: cx + deltaX, y: cy + deltaY}; +} + +} // close module diff --git a/tensorflow/tensorboard/components/tf_graph_common_d3v4/minimap.ts b/tensorflow/tensorboard/components/tf_graph_common_d3v4/minimap.ts new file mode 100644 index 0000000000..9a07323a1d --- /dev/null +++ b/tensorflow/tensorboard/components/tf_graph_common_d3v4/minimap.ts @@ -0,0 +1,327 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the 'License'); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an 'AS IS' BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +module tf.scene { + +/** Show minimap when the viewpoint area is less than X% of the whole area. */ +const FRAC_VIEWPOINT_AREA: number = 0.8; + +export class Minimap { + /** The minimap container. */ + private minimap: HTMLElement; + /** The canvas used for drawing the mini version of the svg. */ + private canvas: HTMLCanvasElement; + /** A buffer canvas used for temporary drawing to avoid flickering. */ + private canvasBuffer: HTMLCanvasElement; + private download: HTMLLinkElement; + private downloadCanvas: HTMLCanvasElement; + + /** The minimap svg used for holding the viewpoint rectangle. */ + private minimapSvg: SVGSVGElement; + /** The rectangle showing the current viewpoint. */ + private viewpoint: SVGRectElement; + /** + * The scale factor for the minimap. The factor is determined automatically + * so that the minimap doesn't violate the maximum width/height specified + * in the constructor. The minimap maintains the same aspect ratio as the + * original svg. + */ + private scaleMinimap: number; + /** The main svg element. */ + private svg: SVGSVGElement; + /** The svg group used for panning and zooming the main svg. */ + private zoomG: SVGGElement; + /** The zoom behavior of the main svg. */ + private mainZoom: d3.ZoomBehavior; + /** The maximum width and height for the minimap. */ + private maxWandH: number; + /** The last translation vector used in the main svg. */ + private translate: [number, number]; + /** The last scaling factor used in the main svg. */ + private scaleMain: number; + /** The coordinates of the viewpoint rectangle. */ + private viewpointCoord: {x: number, y: number}; + /** The current size of the minimap */ + private minimapSize: {width: number, height: number}; + /** Padding (px) due to the main labels of the graph. */ + private labelPadding: number; + /** + * Constructs a new minimap. + * + * @param svg The main svg element. + * @param zoomG The svg group used for panning and zooming the main svg. + * @param mainZoom The main zoom behavior. + * @param minimap The minimap container. + * @param maxWandH The maximum width/height for the minimap. + * @param labelPadding Padding in pixels due to the main graph labels. + */ + constructor(svg: SVGSVGElement, zoomG: SVGGElement, + mainZoom: d3.ZoomBehavior, minimap: HTMLElement, + maxWandH: number, labelPadding: number) { + this.svg = svg; + this.labelPadding = labelPadding; + this.zoomG = zoomG; + this.mainZoom = mainZoom; + this.maxWandH = maxWandH; + let $minimap = d3.select(minimap); + // The minimap will have 2 main components: the canvas showing the content + // and an svg showing a rectangle of the currently zoomed/panned viewpoint. + let $minimapSvg = $minimap.select('svg'); + + // Make the viewpoint rectangle draggable. + let $viewpoint = $minimapSvg.select('rect'); + let dragmove = (d) => { + this.viewpointCoord.x = (d3.event).x; + this.viewpointCoord.y = (d3.event).y; + this.updateViewpoint(); + }; + this.viewpointCoord = {x: 0, y: 0}; + let drag = d3.drag().subject(Object).on('drag', dragmove); + $viewpoint.datum(this.viewpointCoord as any).call(drag); + + // Make the minimap clickable. + $minimapSvg.on('click', () => { + if ((d3.event).defaultPrevented) { + // This click was part of a drag event, so suppress it. + return; + } + // Update the coordinates of the viewpoint. + let width = Number($viewpoint.attr('width')); + let height = Number($viewpoint.attr('height')); + let clickCoords = d3.mouse($minimapSvg.node() as any); + this.viewpointCoord.x = clickCoords[0] - width / 2; + this.viewpointCoord.y = clickCoords[1] - height / 2; + this.updateViewpoint(); + }); + this.viewpoint = $viewpoint.node(); + this.minimapSvg = $minimapSvg.node(); + this.minimap = minimap; + this.canvas = $minimap.select('canvas.first').node(); + this.canvasBuffer = + $minimap.select('canvas.second').node(); + this.downloadCanvas = + $minimap.select('canvas.download').node(); + d3.select(this.downloadCanvas).style('display', 'none'); + this.update(); + } + + /** + * Updates the position and the size of the viewpoint rectangle. + * It also notifies the main svg about the new panned position. + */ + private updateViewpoint(): void { + // Update the coordinates of the viewpoint rectangle. + d3.select(this.viewpoint) + .attr('x', this.viewpointCoord.x) + .attr('y', this.viewpointCoord.y); + // Update the translation vector of the main svg to reflect the + // new viewpoint. + let mainX = - this.viewpointCoord.x * this.scaleMain / this.scaleMinimap; + let mainY = - this.viewpointCoord.y * this.scaleMain / this.scaleMinimap; + this.mainZoom.translateBy(d3.select(this.zoomG), mainX, mainY); + + } + + /** + * Redraws the minimap. Should be called whenever the main svg + * was updated (e.g. when a node was expanded). + */ + update(): void { + let sceneSize = null; + try { + // Get the size of the entire scene. + sceneSize = this.zoomG.getBBox(); + if (sceneSize.width === 0) { + // There is no scene anymore. We have been detached from the dom. + return; + } + } catch (e) { + // Firefox produced NS_ERROR_FAILURE if we have been + // detached from the dom. + return; + } + let $download = d3.select('#graphdownload'); + this.download = $download.node(); + $download.on('click', d => { + this.download.href = this.downloadCanvas.toDataURL('image/png'); + }); + + let $svg = d3.select(this.svg); + // Read all the style rules in the document and embed them into the svg. + // The svg needs to be self contained, i.e. all the style rules need to be + // embedded so the canvas output matches the origin. + let stylesText = ''; + for (let k = 0; k < document.styleSheets.length; k++) { + try { + let cssRules = (document.styleSheets[k]).cssRules || + (document.styleSheets[k]).rules; + if (cssRules == null) { + continue; + } + for (let i = 0; i < cssRules.length; i++) { + // Remove tf-* selectors from the styles. + stylesText += + cssRules[i].cssText.replace(/ ?tf-[\w-]+ ?/g, '') + '\n'; + } + } catch (e) { + if (e.name !== 'SecurityError') { + throw e; + } + } + } + + // Temporarily add the css rules to the main svg. + let svgStyle = $svg.append('style'); + svgStyle.text(stylesText); + + // Temporarily remove the zoom/pan transform from the main svg since we + // want the minimap to show a zoomed-out and centered view. + let $zoomG = d3.select(this.zoomG); + let zoomTransform = $zoomG.attr('transform'); + $zoomG.attr('transform', null); + + // Since we add padding, account for that here. + sceneSize.height += this.labelPadding * 2; + sceneSize.width += this.labelPadding * 2; + + // Temporarily assign an explicit width/height to the main svg, since + // it doesn't have one (uses flex-box), but we need it for the canvas + // to work. + $svg + .attr('width', sceneSize.width) + .attr('height', sceneSize.height); + + // Since the content inside the svg changed (e.g. a node was expanded), + // the aspect ratio have also changed. Thus, we need to update the scale + // factor of the minimap. The scale factor is determined such that both + // the width and height of the minimap are <= maximum specified w/h. + this.scaleMinimap = + this.maxWandH / Math.max(sceneSize.width, sceneSize.height); + + this.minimapSize = { + width: sceneSize.width * this.scaleMinimap, + height: sceneSize.height * this.scaleMinimap + }; + + // Update the size of the minimap's svg, the buffer canvas and the + // viewpoint rect. + d3.select(this.minimapSvg).attr(this.minimapSize); + d3.select(this.canvasBuffer).attr(this.minimapSize); + + // Download canvas width and height are multiples of the style width and + // height in order to increase pixel density of the PNG for clarity. + d3.select(this.downloadCanvas).style( + { width: sceneSize.width, height: sceneSize.height }); + d3.select(this.downloadCanvas).attr( + { width: sceneSize.width * 3, height: sceneSize.height * 3 }); + + if (this.translate != null && this.zoom != null) { + // Update the viewpoint rectangle shape since the aspect ratio of the + // map has changed. + requestAnimationFrame(() => this.zoom()); + } + + // Serialize the main svg to a string which will be used as the rendering + // content for the canvas. + let svgXml = (new XMLSerializer()).serializeToString(this.svg); + + // Now that the svg is serialized for rendering, remove the temporarily + // assigned styles, explicit width and height and bring back the pan/zoom + // transform. + svgStyle.remove(); + $svg.attr('width', null).attr('height', null); + + $zoomG.attr('transform', zoomTransform); + let image = new Image(); + image.onload = () => { + // Draw the svg content onto the buffer canvas. + let context = this.canvasBuffer.getContext('2d'); + context.clearRect(0, 0, this.canvasBuffer.width, + this.canvasBuffer.height); + context.drawImage(image, 0, 0, + this.minimapSize.width, this.minimapSize.height); + requestAnimationFrame(() => { + // Hide the old canvas and show the new buffer canvas. + d3.select(this.canvasBuffer).style('display', null); + d3.select(this.canvas).style('display', 'none'); + // Swap the two canvases. + [this.canvas, this.canvasBuffer] = [this.canvasBuffer, this.canvas]; + }); + let downloadContext = this.downloadCanvas.getContext('2d'); + downloadContext.clearRect(0, 0, this.downloadCanvas.width, + this.downloadCanvas.height); + downloadContext.drawImage(image, 0, 0, + this.downloadCanvas.width, this.downloadCanvas.height); + }; + image.onerror = () => { + let blob = new Blob([svgXml], {type: 'image/svg+xml;charset=utf-8'}); + image.src = URL.createObjectURL(blob); + }; + image.src = + 'data:image/svg+xml;charset=utf-8,' + encodeURIComponent(svgXml); + } + + /** + * Handles changes in zooming/panning. Should be called from the main svg + * to notify that a zoom/pan was performed and this minimap will update it's + * viewpoint rectangle. + * + * @param translate The translate vector, or none to use the last used one. + * @param scale The scaling factor, or none to use the last used one. + */ + zoom(transform?: d3.ZoomTransform): void { + if (this.scaleMinimap == null) { + // Scene is not ready yet. + return; + } + // Update the new translate and scale params, only if specified. + if (transform) { + this.translate = [transform.x, transform.y]; + this.scaleMain = transform.k; + } + + // Update the location of the viewpoint rectangle. + let svgRect = this.svg.getBoundingClientRect(); + let $viewpoint = d3.select(this.viewpoint); + this.viewpointCoord.x = -this.translate[0] * this.scaleMinimap / + this.scaleMain; + this.viewpointCoord.y = -this.translate[1] * this.scaleMinimap / + this.scaleMain; + let viewpointWidth = svgRect.width * this.scaleMinimap / this.scaleMain; + let viewpointHeight = svgRect.height * this.scaleMinimap / this.scaleMain; + $viewpoint + .attr('x', this.viewpointCoord.x) + .attr('y', this.viewpointCoord.y) + .attr('width', viewpointWidth) + .attr('height', viewpointHeight); + // Show/hide the minimap depending on the viewpoint area as fraction of the + // whole minimap. + let mapWidth = this.minimapSize.width; + let mapHeight = this.minimapSize.height; + let x = this.viewpointCoord.x; + let y = this.viewpointCoord.y; + let w = Math.min(Math.max(0, x + viewpointWidth), mapWidth) - + Math.min(Math.max(0, x), mapWidth); + let h = Math.min(Math.max(0, y + viewpointHeight), mapHeight) - + Math.min(Math.max(0, y), mapHeight); + let fracIntersect = (w * h) / (mapWidth * mapHeight); + if (fracIntersect < FRAC_VIEWPOINT_AREA) { + this.minimap.classList.remove('hidden'); + } else { + this.minimap.classList.add('hidden'); + } + } +} + +} // close module tf.scene diff --git a/tensorflow/tensorboard/components/tf_graph_common_d3v4/node.ts b/tensorflow/tensorboard/components/tf_graph_common_d3v4/node.ts new file mode 100644 index 0000000000..e66818f4c8 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_graph_common_d3v4/node.ts @@ -0,0 +1,1072 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the 'License'); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an 'AS IS' BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +module tf.graph.scene.node { + import RenderNodeInfo = tf.graph.render.RenderNodeInfo; + /** + * Select or Create a 'g.nodes' group to a given sceneGroup + * and builds a number of 'g.node' groups inside the group. + * + * Structure Pattern: + * + * + * + * + * ... + * + * + * ... + * + * + * + * + * node name + * + * + * + * + * ... + * + * + * + * @param sceneGroup selection of the container + * @param nodeData array of render node information to map + * @param sceneElement polymer element + * @return selection of the created nodeGroups + */ + export function buildGroup( + sceneGroup, nodeData: render.RenderNodeInfo[], sceneElement) { + let container = + scene.selectOrCreateChild(sceneGroup, 'g', Class.Node.CONTAINER); + // Select all children and join with data. + // (Note that all children of g.nodes are g.node) + let nodeGroups = + (container as any) + .selectAll('g') + .data(nodeData, (d) => { + // make sure that we don't have to swap shape type + return d.node.name + ':' + d.node.type; + }); + + // ENTER + nodeGroups.enter() + .append('g') + .attr('data-name', d => { return d.node.name; }) + .each(function(d) { + let nodeGroup = d3.select(this); + // index node group for quick stylizing + sceneElement.addNodeGroup(d.node.name, nodeGroup); + }) + .merge(nodeGroups) + // ENTER + UPDATE + .attr('class', d => { return Class.Node.GROUP + ' ' + nodeClass(d); }) + .each(function(d) { + let nodeGroup = d3.select(this); + // Add g.in-annotations (always add -- to keep layer order + // consistent.) + let inAnnotationBox = + scene.selectOrCreateChild(nodeGroup, 'g', Class.Annotation.INBOX); + annotation.buildGroup( + inAnnotationBox, d.inAnnotations, d, sceneElement); + + // Add g.out-annotations (always add -- to keep layer order + // consistent.) + let outAnnotationBox = scene.selectOrCreateChild( + nodeGroup, 'g', Class.Annotation.OUTBOX); + annotation.buildGroup( + outAnnotationBox, d.outAnnotations, d, sceneElement); + + // Build .shape first (background of the node). + let shape = buildShape(nodeGroup, d, Class.Node.SHAPE); + if (d.node.isGroupNode) { + addButton(shape, d, sceneElement); + } + addInteraction(shape, d, sceneElement); + + // Build subscene on the top. + subsceneBuild(nodeGroup, d, sceneElement); + + // Build label last. Should be on top of everything else. + let label = labelBuild(nodeGroup, d, sceneElement); + // Do not add interaction to metanode labels as they live inside the + // metanode shape which already has the same interactions. + addInteraction(label, d, sceneElement, d.node.type === NodeType.META); + + stylize(nodeGroup, d, sceneElement); + position(nodeGroup, d); + }); + + // EXIT + nodeGroups.exit() + .each(function(d) { + // remove all indices on remove + sceneElement.removeNodeGroup(d.node.name); + + let nodeGroup = d3.select(this); + if (d.inAnnotations.list.length > 0) { + nodeGroup.select('.' + Class.Annotation.INBOX) + .selectAll('.' + Class.Annotation.GROUP) + .each(a => { sceneElement.removeAnnotationGroup(a, d); }); + } + if (d.outAnnotations.list.length > 0) { + nodeGroup.select('.' + Class.Annotation.OUTBOX) + .selectAll('.' + Class.Annotation.GROUP) + .each(a => { sceneElement.removeAnnotationGroup(a, d); }); + } + }) + .remove(); + return nodeGroups; +}; + +/** + * Update or remove the subscene of a render group node depending on whether it + * is a expanded. If the node is not a group node, this method has no effect. + * + * @param nodeGroup selection of the container + * @param renderNodeInfo the render information for the node. + * @param sceneElement polymer element. + * @return Selection of the subscene group, or null if node group does not have + * a subscene. Op nodes, bridge nodes and unexpanded group nodes will + * not have a subscene. + */ +function subsceneBuild(nodeGroup, + renderNodeInfo: render.RenderGroupNodeInfo, sceneElement) { + if (renderNodeInfo.node.isGroupNode) { + if (renderNodeInfo.expanded) { + // Recursively build the subscene. + return scene.buildGroup(nodeGroup, renderNodeInfo, sceneElement, + Class.Subscene.GROUP); + } + // Clean out existing subscene if the node is not expanded. + scene.selectChild(nodeGroup, 'g', Class.Subscene.GROUP).remove(); + } + return null; +}; + +/** + * Translate the subscene of the given node group + */ +function subscenePosition(nodeGroup, d: render.RenderNodeInfo) { + let x0 = d.x - d.width / 2.0 + d.paddingLeft; + let y0 = d.y - d.height / 2.0 + d.paddingTop; + + let subscene = scene.selectChild(nodeGroup, 'g', Class.Subscene.GROUP); + scene.translate(subscene, x0, y0); +}; + +/** + * Add an expand/collapse button to a group node + * + * @param selection The group node selection. + * @param d Info about the node being rendered. + * @param sceneElement polymer element. + */ +function addButton(selection, d: render.RenderNodeInfo, sceneElement) { + let group = + scene.selectOrCreateChild(selection, 'g', Class.Node.BUTTON_CONTAINER); + scene.selectOrCreateChild(group, 'circle', Class.Node.BUTTON_CIRCLE); + scene.selectOrCreateChild(group, 'path', Class.Node.EXPAND_BUTTON) + .attr('d', 'M0,-2.2 V2.2 M-2.2,0 H2.2'); + scene.selectOrCreateChild(group, 'path', Class.Node.COLLAPSE_BUTTON) + .attr('d', 'M-2.2,0 H2.2'); + (group as any).on('click', (d: any) => { + // Stop this event's propagation so that it isn't also considered a + // node-select. + (d3.event).stopPropagation(); + sceneElement.fire('node-toggle-expand', {name: d.node.name}); + }); + scene.positionButton(group, d); +}; + +/** + * Fire node-* events when the selection is interacted. + * + * @param disableInteraction When true, have the provided selection + * ignore all pointer events. Used for text labels inside of metanodes, which + * don't need interaction as their surrounding shape has interaction, and if + * given interaction would cause conflicts with the expand/collapse button. + */ +function addInteraction(selection, d: render.RenderNodeInfo, + sceneElement, disableInteraction?: boolean) { + if (disableInteraction) { + selection.attr('pointer-events', 'none'); + return; + } + + let contextMenuFunction = contextmenu.getMenu( + getContextMenu(d.node, sceneElement)); + selection + .on('dblclick', + d => { + sceneElement.fire('node-toggle-expand', {name: d.node.name}); + }) + .on('mouseover', + d => { + // don't send mouseover over expanded group, + // otherwise it is causing too much glitches + if (sceneElement.isNodeExpanded(d)) { + return; + } + + sceneElement.fire('node-highlight', {name: d.node.name}); + }) + .on('mouseout', + d => { + // don't send mouseover over expanded group, + // otherwise it is causing too much glitches + if (sceneElement.isNodeExpanded(d)) { + return; + } + + sceneElement.fire('node-unhighlight', {name: d.node.name}); + }) + .on('click', + d => { + // Stop this event's propagation so that it isn't also considered + // a graph-select. + (d3.event).stopPropagation(); + sceneElement.fire('node-select', {name: d.node.name}); + }) + .on('contextmenu', (d, i) => { + sceneElement.fire('node-select', {name: d.node.name}); + contextMenuFunction.call(d, i); + }); +}; + +/** + * Returns the d3 context menu specification for the provided node. + */ +export function getContextMenu(node: Node, sceneElement) { + let menu = [{ + title: (d): string => { + return getIncludeNodeButtonString(node.include); + }, + action: (elm, d, i) => { + sceneElement.fire('node-toggle-extract', {name: node.name}); + } + }]; + if (canBeInSeries(node)) { + menu.push({ + title: d => { return getGroupSettingLabel(node); }, + action: (elm, d, i) => { + sceneElement.fire( + 'node-toggle-seriesgroup', {name: getSeriesName(node)}); + } + }); + } + return menu; +} + +/** Returns if a node can be part of a grouped series */ +export function canBeInSeries(node: Node) { + return getSeriesName(node) !== null; +} + +/** + * Returns the name of the possible grouped series containing this node. + * Returns null if the node cannot be part of a grouped series of nodes. + */ +export function getSeriesName(node: Node) { + if (!node) { + return null; + } + if (node.type === NodeType.SERIES) { + return node.name; + } + if (node.type === NodeType.OP) { + let op = node; + return op.owningSeries; + } + return null; +} + +/** + * Returns the SeriesNode that represents the series that the provided node + * is contained in (or itself if the provided node is itself a SeriesNode). + * Returns null if the node is not rendered as part of a series. + */ +function getContainingSeries(node: Node) { + let s: SeriesNode = null; + if (!node) { + return null; + } else if (node.type === NodeType.SERIES) { + s = node; + } else if (node.parentNode && node.parentNode.type === NodeType.SERIES) { + s = node.parentNode; + } + return s; +} + +/** + * Returns the label for a button to toggle the group setting of the provided + * node. + */ +export function getGroupSettingLabel(node: Node) { + return tf.graph.getGroupSeriesNodeButtonString( + getContainingSeries(node) !== null ? tf.graph.SeriesGroupingType.GROUP : + tf.graph.SeriesGroupingType.UNGROUP); +} + +/** + * Append svg text for label and assign data. + * @param nodeGroup + * @param renderNodeInfo The render node information for the label. + * @param sceneElement polymer element. + */ +function labelBuild(nodeGroup, renderNodeInfo: render.RenderNodeInfo, + sceneElement) { + let namePath = renderNodeInfo.node.name.split('/'); + let text = namePath[namePath.length - 1]; + + // Truncate long labels for unexpanded Metanodes. + let useFontScale = renderNodeInfo.node.type === NodeType.META && + !renderNodeInfo.expanded; + + let label = scene.selectOrCreateChild(nodeGroup, 'text', Class.Node.LABEL); + + // Make sure the label is visually on top among its siblings. + let labelNode = label.node(); + labelNode.parentNode.appendChild(labelNode); + + label.attr('dy', '.35em').attr('text-anchor', 'middle'); + if (useFontScale) { + if (text.length > sceneElement.maxMetanodeLabelLength) { + text = text.substr(0, sceneElement.maxMetanodeLabelLength - 2) + '...'; + } + let scale = getLabelFontScale(sceneElement); + label.attr('font-size', scale(text.length) + 'px'); + } + + let txtElement = >label.text(text); + enforceLabelWidth(txtElement, renderNodeInfo.node.type, renderNodeInfo); + return label; +} +/** + * This function shortens text which would exceed the maximum pixel width of + * a label. + * + * @param txtElementSelection The text element containing the label's text as d3 + * selection. + * @param nodeType The type of the node the label belongs to. If the node is + * an annotation, the value is -1. Label widths are defined in + * layout.PARAMS.nodeSize.{meta|op|...}.maxLabelWidth for nodes and + * layout.PARAMS.annotations.labelWidth for annotations. + * @param renderNodeInfo The render information about the node, required to + * determine whether META nodes are collapsed or expanded. + */ +export function enforceLabelWidth( + txtElementSelection: d3.Selection, nodeType: NodeType | number, + renderNodeInfo?: render.RenderNodeInfo) { + // Get text element itself and its on-screen width. + let txtNode = txtElementSelection.node(); + let computedTxtLength = txtNode.getComputedTextLength(); + let labelContent = txtNode.textContent; + + // Get maximum length from settings. + let maxLength = null; + switch (nodeType) { + case NodeType.META: + if (renderNodeInfo && !renderNodeInfo.expanded) { // Only trim text if + // node expanded. + maxLength = layout.PARAMS.nodeSize.meta.maxLabelWidth; + } + break; + + case NodeType.OP: + maxLength = layout.PARAMS.nodeSize.op.maxLabelWidth; + break; + + case -1: + maxLength = layout.PARAMS.annotations.maxLabelWidth; + break; + + default: + break; + } + + // Return if no max length provided for node type, or current label length is + // less than or equal to the provided length limit. + if (maxLength === null || computedTxtLength <= maxLength) { + return; + } + + // Find the index of the character which exceeds the width. + // getSubStringLength performs far better than getComputedTextLength, and + // results in a 3x speed-up on average. + let index = 1; + while (txtNode.getSubStringLength(0, index) < maxLength) { + index++; + } + + // Shorten the label starting at the string length known to be one + // character above max pixel length. + // When shortened the original label's substring is concatenated with + // '...', baseText contains the substring not including the '...'. + let baseText = txtNode.textContent.substr(0, index); + do { + baseText = baseText.substr(0, baseText.length - 1); + + // Recompute text length. + txtNode.textContent = baseText + '...'; + computedTxtLength = txtNode.getComputedTextLength(); + } while (computedTxtLength > maxLength && baseText.length > 0); + + // Add tooltip with full name and return. + return txtElementSelection.append('title').text(labelContent); +} + +/** + * d3 scale used for sizing font of labels, used by labelBuild, + * initialized once by getLabelFontScale. + */ +let fontScale = null; +function getLabelFontScale(sceneElement) { + if (!fontScale) { + fontScale = d3.scaleLinear() + .domain([sceneElement.maxMetanodeLabelLengthLargeFont, + sceneElement.maxMetanodeLabelLength]) + .range([sceneElement.maxMetanodeLabelLengthFontSize, + sceneElement.minMetanodeLabelLengthFontSize]).clamp(true); + } + return fontScale; +} + +/** + * Set label position of a given node group + */ +function labelPosition(nodeGroup, cx: number, cy: number, + yOffset: number) { + scene.selectChild(nodeGroup, 'text', Class.Node.LABEL) + .transition() + .attr('x', cx) + .attr('y', cy + yOffset); +}; + +/** + * Select or append/insert shape for a node and assign renderNode + * as the shape's data. + * + * @param nodeGroup + * @param d Render node information. + * @param nodeClass class for the element. + * @return Selection of the shape. + */ +export function buildShape(nodeGroup, d, nodeClass: string) { + // Create a group to house the underlying visual elements. + let shapeGroup = scene.selectOrCreateChild(nodeGroup, 'g', nodeClass); + // TODO(jimbo): DOM structure should be templated in HTML somewhere, not JS. + switch (d.node.type) { + case NodeType.OP: + scene.selectOrCreateChild(shapeGroup, 'ellipse', Class.Node.COLOR_TARGET); + break; + case NodeType.SERIES: + // Choose the correct stamp to use to represent this series. + let stampType = 'annotation'; + let groupNodeInfo = d; + if (groupNodeInfo.coreGraph) { + stampType = + groupNodeInfo.node.hasNonControlEdges ? 'vertical' : 'horizontal'; + } + let classList = [Class.Node.COLOR_TARGET]; + if (groupNodeInfo.isFadedOut) { + classList.push('faded-ellipse'); + } + scene.selectOrCreateChild(shapeGroup, 'use', classList) + .attr('xlink:href', '#op-series-' + stampType + '-stamp'); + scene.selectOrCreateChild(shapeGroup, 'rect', Class.Node.COLOR_TARGET) + .attr('rx', d.radius).attr('ry', d.radius); + break; + case NodeType.BRIDGE: + scene.selectOrCreateChild(shapeGroup, 'rect', Class.Node.COLOR_TARGET) + .attr('rx', d.radius).attr('ry', d.radius); + break; + case NodeType.META: + scene.selectOrCreateChild(shapeGroup, 'rect', Class.Node.COLOR_TARGET) + .attr('rx', d.radius).attr('ry', d.radius); + break; + default: + throw Error('Unrecognized node type: ' + d.node.type); + } + return shapeGroup; +}; + +export function nodeClass(d: render.RenderNodeInfo) { + switch (d.node.type) { + case NodeType.OP: + return Class.OPNODE; + case NodeType.META: + return Class.METANODE; + case NodeType.SERIES: + return Class.SERIESNODE; + case NodeType.BRIDGE: + return Class.BRIDGENODE; + case NodeType.ELLIPSIS: + return Class.ELLIPSISNODE; + }; + throw Error('Unrecognized node type: ' + d.node.type); +}; + +/** Modify node and its subscene and its label's positional attributes */ +function position(nodeGroup, d: render.RenderNodeInfo) { + let shapeGroup = scene.selectChild(nodeGroup, 'g', Class.Node.SHAPE); + let cx = layout.computeCXPositionOfNodeShape(d); + switch (d.node.type) { + case NodeType.OP: { + // position shape + let shape = scene.selectChild(shapeGroup, 'ellipse'); + scene.positionEllipse(shape, cx, d.y, d.coreBox.width, d.coreBox.height); + labelPosition(nodeGroup, cx, d.y, d.labelOffset); + break; + } + case NodeType.META: { + // position shape + let shape = scene.selectChild(shapeGroup, 'rect'); + if (d.expanded) { + scene.positionRect(shape, d.x, d.y, d.width, d.height); + subscenePosition(nodeGroup, d); + // put label on top + labelPosition(nodeGroup, cx, d.y, + - d.height / 2 + d.labelHeight / 2); + } else { + scene.positionRect(shape, cx, d.y, d.coreBox.width, d.coreBox.height); + labelPosition(nodeGroup, cx, d.y, 0); + } + break; + } + case NodeType.SERIES: { + let shape = scene.selectChild(shapeGroup, 'use'); + if (d.expanded) { + scene.positionRect(shape, d.x, d.y, d.width, d.height); + subscenePosition(nodeGroup, d); + // put label on top + labelPosition(nodeGroup, cx, d.y, + - d.height / 2 + d.labelHeight / 2); + } else { + scene.positionRect(shape, cx, d.y, d.coreBox.width, d.coreBox.height); + labelPosition(nodeGroup, cx, d.y, d.labelOffset); + } + break; + } + case NodeType.BRIDGE: { + // position shape + // NOTE: In reality, these will not be visible, but it helps to put them + // in the correct position for debugging purposes. + let shape = scene.selectChild(shapeGroup, 'rect'); + scene.positionRect(shape, d.x, d.y, d.width, d.height); + break; + } + default: { throw Error('Unrecognized node type: ' + d.node.type); } + } +}; + +/** Enum specifying the options to color nodes by */ +export enum ColorBy {STRUCTURE, DEVICE, XLA_CLUSTER, COMPUTE_TIME, MEMORY} +; + +/** + * Returns the fill color for the node given its state and the 'color by' + * option. + */ +export function getFillForNode(templateIndex, colorBy, + renderInfo: render.RenderNodeInfo, isExpanded: boolean): string { + let colorParams = render.MetanodeColors; + switch (colorBy) { + case ColorBy.STRUCTURE: + if (renderInfo.node.type === NodeType.META) { + let tid = (renderInfo.node).templateId; + return tid === null ? + colorParams.UNKNOWN : + colorParams.STRUCTURE_PALETTE(templateIndex(tid), isExpanded); + } else if (renderInfo.node.type === NodeType.SERIES) { + // If expanded, we're showing the background rect, which we want to + // appear gray. Otherwise we're showing a stack of ellipses which we + // want to show white. + return isExpanded ? colorParams.EXPANDED_COLOR : 'white'; + } else if (renderInfo.node.type === NodeType.BRIDGE) { + return renderInfo.structural ? + '#f0e' : + (renderInfo.node).inbound ? '#0ef' : '#fe0'; + } else { + // Op nodes are white. + return 'white'; + } + case ColorBy.DEVICE: + if (renderInfo.deviceColors == null) { + // Return the hue for unknown device. + return colorParams.UNKNOWN; + } + let id = renderInfo.node.name; + let escapedId = tf.graph.util.escapeQuerySelector(id); + let gradientDefs = d3.select('svg#svg defs #linearGradients'); + let linearGradient = gradientDefs.select('linearGradient#' + escapedId); + // If the linear gradient is not there yet, create it. + if (linearGradient.size() === 0) { + linearGradient = gradientDefs.append('linearGradient').attr('id', id); + // Re-create the stops of the linear gradient. + linearGradient.selectAll('*').remove(); + let cumulativeProportion = 0; + // For each device, create a stop using the proportion of that device. + _.each(renderInfo.deviceColors, d => { + let color = d.color; + linearGradient.append('stop') + .attr('offset', cumulativeProportion) + .attr('stop-color', color); + linearGradient.append('stop') + .attr('offset', cumulativeProportion + d.proportion) + .attr('stop-color', color); + cumulativeProportion += d.proportion; + }); + } + return isExpanded ? colorParams.EXPANDED_COLOR : `url(#${escapedId})`; + case ColorBy.XLA_CLUSTER: + return isExpanded ? colorParams.EXPANDED_COLOR : + renderInfo.xlaClusterColor || colorParams.UNKNOWN; + case ColorBy.COMPUTE_TIME: + return isExpanded ? + colorParams.EXPANDED_COLOR : renderInfo.computeTimeColor || + colorParams.UNKNOWN; + case ColorBy.MEMORY: + return isExpanded ? + colorParams.EXPANDED_COLOR : renderInfo.memoryColor || + colorParams.UNKNOWN; + default: + throw new Error('Unknown case to color nodes by'); + } +} + +/** + * Modify node style by toggling class and assign attributes (only for things + * that can't be done in css). + */ +export function stylize(nodeGroup, renderInfo: render.RenderNodeInfo, + sceneElement, nodeClass?) { + nodeClass = nodeClass || Class.Node.SHAPE; + let isHighlighted = sceneElement.isNodeHighlighted(renderInfo.node.name); + let isSelected = sceneElement.isNodeSelected(renderInfo.node.name); + let isExtract = renderInfo.isInExtract || renderInfo.isOutExtract; + let isExpanded = renderInfo.expanded; + let isFadedOut = renderInfo.isFadedOut; + nodeGroup.classed('highlighted', isHighlighted); + nodeGroup.classed('selected', isSelected); + nodeGroup.classed('extract', isExtract); + nodeGroup.classed('expanded', isExpanded); + nodeGroup.classed('faded', isFadedOut); + + // Main node always exists here and it will be reached before subscene, + // so d3 selection is fine here. + let node = nodeGroup.select('.' + nodeClass + ' .' + Class.Node.COLOR_TARGET); + let fillColor = getFillForNode(sceneElement.templateIndex, + ColorBy[sceneElement.colorBy.toUpperCase()], + renderInfo, isExpanded); + node.style('fill', fillColor); + + // Choose outline to be darker version of node color if the node is a single + // color and is not selected. + node.style('stroke', isSelected ? null : getStrokeForFill(fillColor)); +}; + +/** + * Given a node's fill color/gradient, determine the stroke for the node. + */ +export function getStrokeForFill(fill: string) { + // If node is colored by a gradient, then use a dark gray outline. + return fill.substring(0, 3) === 'url' ? + render.MetanodeColors.GRADIENT_OUTLINE : + d3.rgb(fill).darker().toString(); +} + +/** + * Finds selected node and highlights all nodes which are providing direct + * or indirect input to the node and all edges connecting these nodes + * together and to the selected node. + * + * @param renderGraphInfo Information on the rendered state of the graph. + */ +export function traceInputs(renderGraphInfo: tf.graph.render.RenderGraphInfo) { + // Reset all styling. + d3.selectAll('.input-highlight').classed('input-highlight', false); + d3.selectAll('.non-input').classed('non-input', false); + d3.selectAll('.input-parent').classed('input-parent', false); + d3.selectAll('.input-child').classed('input-child', false); + d3.selectAll('.input-edge-highlight').classed('input-edge-highlight', false); + d3.selectAll('.non-input-edge-highlight') + .classed('non-input-edge-highlight', false); + d3.selectAll('.input-highlight-selected') + .classed('input-highlight-selected', false); + + // Extract currently selected node. Return if input tracing disabled or no + // node is selected. + let selectedNodeSelectorString = 'g.node.selected,g.op.selected'; + let node = d3.select(selectedNodeSelectorString); + let currentNode = undefined; + if (renderGraphInfo && renderGraphInfo.traceInputs && node && node[0] && + node[0][0]) { + currentNode = node[0][0] as Element; + } else { + return; + } + let nodeName = currentNode.getAttribute('data-name'); + let opNodes = _getAllContainedOpNodes(nodeName, renderGraphInfo); + let allTracedNodes = {}; + _.each(opNodes, function(nodeInstance) { + allTracedNodes = + traceAllInputsOfOpNode(renderGraphInfo, nodeInstance, allTracedNodes); + }); + + d3.selectAll(selectedNodeSelectorString) + // Remove the input-highlight from the selected node. + .classed('input-highlight', false) + // Add input-highlight-selected class to selected node, which allows + // treating the selected not as a special case of an input node. + .classed('input-highlight-selected', true) + + // Highlight all parent nodes of each OpNode as input parent to allow + // specific highlighting. + let highlightedNodes = Object.keys(allTracedNodes); + let visibleNodes = + _findVisibleParentsFromOpNodes(renderGraphInfo, highlightedNodes); + _markParentsOfNodes(visibleNodes); + + // Attach class to all non-input nodes and edges for styling. + d3.selectAll( + 'g.node:not(.selected):not(.input-highlight)' + + ':not(.input-parent):not(.input-children)') + .classed('non-input', true) + .each(function(d: RenderNodeInfo) { + // Mark all nodes with the specified name as non-inputs. This + // results in Annotation nodes which are attached to inputs to be + // tagged as well. + let nodeName = d.node.name; + d3.selectAll(`[data-name="${nodeName}"]`).classed('non-input', true); + }); + d3.selectAll('g.edge:not(.input-edge-highlight)') + .classed('non-input-edge-highlight', true); +} + +/** + * Recursively find all op nodes contained by the node identified by the + * provided name. + * @param nodeName The meta or op node of which the OpNode instances are + * required. + * @param renderGraphInfo The rendered graph information object. + * @returns {Array} An array of OpNodeImpl instances. + */ +export function _getAllContainedOpNodes( + nodeName: string, renderGraphInfo: tf.graph.render.RenderGraphInfo) { + let opNodes = []; + + // Get current node. + let node = renderGraphInfo.getNodeByName(nodeName) as tf.graph.GroupNode | + tf.graph.OpNode; + + // If node is already OpNode then return the node plus its input embeddings. + if (node instanceof tf.graph.OpNodeImpl) { + return [node].concat(node.inEmbeddings); + } + + // Otherwise, make recursive call for each node contained by the GroupNode. + let childNodeNames = (node as tf.graph.GroupNode).metagraph.nodes(); + _.each(childNodeNames, function(childNodeName) { + opNodes = + opNodes.concat(_getAllContainedOpNodes(childNodeName, renderGraphInfo)); + }); + + return opNodes; +} + +/** + * When resolving inputs of a node the visible parent node of each input + * node (i.e. the first parent which is rendered to the screen) needs to be + * found, and since such a node may contain several input OpNodes a map + * of the visible parent to all the input OpNodes it contains is provided by + * opNodes. + */ +interface VisibleParent { + visibleParent: Node; + opNodes: OpNode[]; +} + +export function traceAllInputsOfOpNode( + renderGraphInfo: tf.graph.render.RenderGraphInfo, startNode: OpNode, + allTracedNodes: Object) { + // To prevent infinite loops due to cyclical relationships and improving + // performance by tracing OpNode which is input to 2+ nodes only once. + if (allTracedNodes[startNode.name]) { + return allTracedNodes; + } else { + allTracedNodes[startNode.name] = true; + } + // Extract the inputs. + let inputs = startNode.inputs; + // Get visible parent. + let currentVisibleParent = getVisibleParent(renderGraphInfo, startNode); + // Mark as input node. + d3.select(`.node[data-name="${currentVisibleParent.name}"]`) + .classed('input-highlight', true); + + // Find the visible parent of each input. + let visibleInputs = {}; + _.each(inputs, function(nodeInstance) { + let resolvedNode = renderGraphInfo.getNodeByName(nodeInstance.name); + if (resolvedNode === undefined) { + // Node could not be found in rendered Hierarchy, which happens when + // tracing inputs of a SummaryNode. + return; + } + // Ensure node is resolved to OpNode if name collision with Metanode exists. + if (resolvedNode instanceof MetanodeImpl) { + let resolvedNodeName = tf.graph.getStrictName(resolvedNode.name); + resolvedNode = renderGraphInfo.getNodeByName(resolvedNodeName) as OpNode; + } + + let visibleParent = getVisibleParent(renderGraphInfo, resolvedNode); + + // Append OpNode to visible parent entry. + let visibleInputsEntry = visibleInputs[visibleParent.name]; + if (visibleInputsEntry) { + visibleInputsEntry.opNodes.push(resolvedNode); + } else { // Create new entry. + visibleInputs[visibleParent.name] = { + visibleParent: visibleParent, + opNodes: [resolvedNode] + } as VisibleParent; + } + }); + + // Find all parents of the start node. + let startNodeParents = {}; + let indexedStartNodeParents = [currentVisibleParent]; + startNodeParents[currentVisibleParent.name] = { + traced: false, + index: 0, + connectionEndpoints: [] + }; + + let currentNode = currentVisibleParent as Node; + for (let index = 1; currentNode.name !== tf.graph.ROOT_NAME; index++) { + currentNode = currentNode.parentNode; + startNodeParents[currentNode.name] = { + traced: false, + index: index, + connectionEndpoints: [] + }; + indexedStartNodeParents[index] = currentNode; + } + + // Find first mutual parent of each input node and highlight connection. + _.forOwn(visibleInputs, function(visibleParentInfo: VisibleParent, key) { + let nodeInstance = visibleParentInfo.visibleParent; + // Make recursive call for each input-OpNode contained by the visible + // parent. + _.each(visibleParentInfo.opNodes, function(opNode: OpNode) { + allTracedNodes = + traceAllInputsOfOpNode(renderGraphInfo, opNode, allTracedNodes); + }); + + if (nodeInstance.name !== currentVisibleParent.name) { + _createVisibleTrace( + nodeInstance, startNodeParents, indexedStartNodeParents); + } + }); + + return allTracedNodes; +} + +/** + * Colors the edges to connect the passed node to the start node. This is + * done by: + * + * a) Finding the first (visible) common parent in the rendered + * hierarchy. + * NB: There are 2 types of connections: + * 1) Direct connections between node A + * and B, marked below as II, + * 2) Connections from any node A to its parent, A'. Marked below as I and III. + * For type 2 connection you need to know the inner-nested node, the + * direct parent, and the ultimate destination of the connection. + * + * A_parent B_parent + * +--------+ +---------+ + * | | | | + * | +--+ I| II |III+--+ | + * | |A +---------->+B | | + * | +--+ | | +--+ | + * | | | | + * +--------+ +---------+ + * + * + * b) Highlighting the direct connection between the parents of A and B, + * called A_parent and B_parent, s.t. A_parent and B_parent are children of the + * mutual parent of A and B found in a), marked above as II. + * + * c) Highlighting the connection from A to A_parent and B to B_parent + * (through all layers of parents between A and A_parent and B and B_parent, + * respectively). Marked above as I and III. + * + * @param nodeInstance The instance of the node to use as destination node, B. + * @param startNodeParents Map of startNodeParent names to information objects + * about the parent. + * @param indexedStartNodeParents An array of all parents of the start node. + * This is required to find the child of the mutual parent which is a parent + * of the start node. + * @private + */ +function _createVisibleTrace( + nodeInstance: Node, startNodeParents, indexedStartNodeParents: Node[]) { + let currentNode = nodeInstance; + let previousNode = nodeInstance; + + // Ascend through parents until a mutual parent is found with the start + // node. + let destinationParentPairs = []; + while (!startNodeParents[currentNode.name]) { + if (previousNode.name !== currentNode.name) { + destinationParentPairs.push([previousNode, currentNode]); + } + previousNode = currentNode; + currentNode = currentNode.parentNode; + } + + // Connection between nodes is drawn between the parents of each + // respective node, both of which share the mutual parent. + let startNodeIndex = startNodeParents[currentNode.name].index; + let startNodeName = + indexedStartNodeParents[Math.max(startNodeIndex - 1, 0)].name; + + let startNodeTopParentName = startNodeName; + let targetNodeTopParentName = previousNode.name; + + let endNodeName = previousNode.name; + d3.selectAll(`[data-edge="${endNodeName}--${startNodeName}"]`) + .classed('input-edge-highlight', true); + + // Trace up the parents of the input. + _.each(destinationParentPairs, function(value) { + let inner = value[0]; + let outer = value[1]; + let edgeSelector = `[data-edge="${inner.name}--${startNodeTopParentName}` + + `~~${outer.name}~~OUT"]`; + d3.selectAll(edgeSelector).classed('input-edge-highlight', true); + }); + + // Trace up the parents of the start node. + for (let index = 1; index < startNodeIndex; index++) { + let inner = indexedStartNodeParents[index - 1]; + let outer = indexedStartNodeParents[index]; + let edgeSelector = `[data-edge="${targetNodeTopParentName}~~${outer.name}` + + `~~IN--${inner.name}"]`; + d3.selectAll(edgeSelector).classed('input-edge-highlight', true); + } +} + +/** + * Creates map { [name: string] -> Node } of all visible / rendered parents + * of the nodes identified by the node names passed in. + * + * @param renderGraphInfo The information on the rendered graph. + * @param nodeNames String array of node names. + * @returns {[nodeName: string]: Node} + * @private + */ +function _findVisibleParentsFromOpNodes(renderGraphInfo, nodeNames: string[]) { + let visibleParents: {[nodeName: string]: Node} = {}; + _.each(nodeNames, function(nodeName) { + let currentNode = renderGraphInfo.getNodeByName(nodeName); + let visibleParent = getVisibleParent(renderGraphInfo, currentNode); + visibleParents[visibleParent.name] = visibleParent; + }); + + return visibleParents; +} + +/** + * Traverse through the parents of all nodes in the list and mark each + * encountered node as input-parent. + * @param visibleNodes Map of input nodes, have to be visible/rendered when + * called. + * @private + */ +function _markParentsOfNodes(visibleNodes: {[nodeName: string]: Node}) { + _.forOwn(visibleNodes, function(nodeInstance: Node) { + // Mark all parents of the node as input-parents. + let currentNode = nodeInstance; + + while (currentNode.name !== tf.graph.ROOT_NAME) { + let renderedElement = d3.select(`.node[data-name="${currentNode.name}"]`); + // Only mark the element as a parent node to an input if it is not + // marked as input node itself. + if (renderedElement[0][0] && + !renderedElement.classed('input-highlight') && + !renderedElement.classed('selected') && + // OpNode only parent if start node is embedded node, in which case + // the OpNode should be faded as well. + !renderedElement.classed('op')) { + renderedElement.classed('input-parent', true); + } + currentNode = currentNode.parentNode; + } + }); +} + +/** + * Find the parent of the passed in op node which is expanded. This is done + * by going through all parents until the parent's parent is expanded, thus + * finding the first unexpanded parent which is rendered on the screen. + * @param renderGraphInfo The graph info object used to gain access to the + * render info of the parents. + * @param currentNode The node whose parent is to be found. + * @returns Node + */ +export function getVisibleParent( + renderGraphInfo: tf.graph.render.RenderGraphInfo, + currentNode: tf.graph.Node) { + let found = false; + let currentParent = currentNode; + + while (!found) { + // Get parent element, to extract name. + currentNode = currentParent; + currentParent = currentNode.parentNode; + + if (currentParent === undefined) { + found = true; + } else { + let renderNode = renderGraphInfo.getRenderNodeByName(currentParent.name); + // Found if node is rendered on the screen (renderNode truthy), and + // the parent is either expanded (i.e. it is a metanode or seriesnode) + // or the parent is an OpNode in which case currentNode is an embedded + // node which has another OpNode as parent. + if (renderNode && + (renderNode.expanded || currentParent instanceof graph.OpNodeImpl)) { + found = true; + } + } + } // Close while loop. + return currentNode; +} +} // Close module. diff --git a/tensorflow/tensorboard/components/tf_graph_common_d3v4/parser.ts b/tensorflow/tensorboard/components/tf_graph_common_d3v4/parser.ts new file mode 100644 index 0000000000..04d879ef91 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_graph_common_d3v4/parser.ts @@ -0,0 +1,284 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the 'License'); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an 'AS IS' BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +module tf.graph.parser { + +/** + * Parses a native js value, which can be either a string, boolean or number. + * + * @param value The value to be parsed. + */ +function parseValue(value: string): string|number|boolean { + if (value === 'true') { + return true; + } + if (value === 'false') { + return false; + } + let firstChar = value[0]; + if (firstChar === '"') { + return value.substring(1, value.length - 1); + } + let num = parseFloat(value); + return isNaN(num) ? value : num; +} + +/** + * Fetches a text file and returns a promise of the result. + */ +export function fetchPbTxt(filepath: string): Promise { + return new Promise(function(resolve, reject) { + const request = new XMLHttpRequest(); + request.open('GET', filepath); + request.responseType = 'arraybuffer'; + + request.onerror = () => reject(request.status); + request.onload = () => resolve(request.response); + + request.send(null); + }); +} + +/** + * Fetches the metadata file, parses it and returns a promise of the result. + */ +export function fetchAndParseMetadata(path: string, tracker: ProgressTracker) { + return tf.graph.util + .runTask( + 'Reading metadata pbtxt', 40, + () => { + if (path == null) { + return Promise.resolve(null); + } + return fetchPbTxt(path); + }, + tracker) + .then((arrayBuffer: ArrayBuffer) => { + return tf.graph.util.runAsyncPromiseTask( + 'Parsing metadata.pbtxt', 60, () => { + return arrayBuffer != null ? parseStatsPbTxt(arrayBuffer) : + Promise.resolve(null); + }, tracker); + }); +} + +/** + * Fetches the graph file, parses it and returns a promise of the result. The + * result will be undefined if the graph is empty. + */ +export function fetchAndParseGraphData(path: string, pbTxtFile: Blob, + tracker: ProgressTracker) { + return tf.graph.util + .runTask( + 'Reading graph pbtxt', 40, + () => { + if (pbTxtFile) { + return new Promise(function(resolve, reject) { + let fileReader = new FileReader(); + fileReader.onload = () => resolve(fileReader.result); + fileReader.onerror = () => reject(fileReader.error); + fileReader.readAsArrayBuffer(pbTxtFile); + }); + } else { + return fetchPbTxt(path); + } + }, + tracker) + .then((arrayBuffer: ArrayBuffer) => { + return tf.graph.util.runTask('Parsing graph.pbtxt', 60, () => { + return parseGraphPbTxt(arrayBuffer); + }, tracker); + }); +} + +/** + * Parse a file object in a streaming fashion line by line (or custom delim). + * Can handle very large files. + * @param input The file object as an array buffer. + * @param callback The callback called on each line + * @param chunkSize The size of each read chunk. (optional) + * @param delim The delimiter used to split a line. (optional) + * @returns A promise for when it is finished. + */ +export function streamParse( + arrayBuffer: ArrayBuffer, callback: (string) => void, + chunkSize: number = 1000000, delim: string = '\n'): Promise { + return new Promise(function(resolve, reject) { + let offset = 0; + let bufferSize = arrayBuffer.byteLength - 1; + let data = ''; + + function readHandler(str) { + offset += chunkSize; + let parts = str.split(delim); + let first = data + parts[0]; + if (parts.length === 1) { + data = first; + readChunk(offset, chunkSize); + return; + } + data = parts[parts.length - 1]; + callback(first); + for (let i = 1; i < parts.length - 1; i++) { + callback(parts[i]); + } + if (offset >= bufferSize) { + if (data) { + callback(data); + } + resolve(true); + return; + } + readChunk(offset, chunkSize); + } + + function readChunk(offset: number, size: number) { + const arrayBufferChunk = arrayBuffer.slice(offset, offset + size); + + const blob = new Blob([arrayBufferChunk]); + const file = new FileReader(); + file.onload = (e: any) => readHandler(e.target.result); + file.readAsText(blob); + } + + readChunk(offset, chunkSize); + }); +} + +/** + * Since proto-txt doesn't explicitly say whether an attribute is repeated + * (an array) or not, we keep a hard-coded list of attributes that are known + * to be repeated. This list is used in parsing time to convert repeated + * attributes into arrays even when the attribute only shows up once in the + * object. + */ +const GRAPH_REPEATED_FIELDS: {[attrPath: string]: boolean} = { + 'node': true, + 'node.input': true, + 'node.attr': true, + 'node.attr.value.list.type': true, + 'node.attr.value.shape.dim': true, + 'node.attr.value.tensor.string_val': true, + 'node.attr.value.tensor.tensor_shape.dim': true, + 'node.attr.value.list.shape': true, + 'node.attr.value.list.shape.dim': true, + 'node.attr.value.list.s': true +}; + +const METADATA_REPEATED_FIELDS: {[attrPath: string]: boolean} = { + 'step_stats.dev_stats': true, + 'step_stats.dev_stats.node_stats': true, + 'step_stats.dev_stats.node_stats.output': true, + 'step_stats.dev_stats.node_stats.memory': true, + 'step_stats.dev_stats.node_stats.output.tensor_description.shape.dim': true +}; + +/** + * Parses an ArrayBuffer of a proto txt file into a raw Graph object. + */ +export function parseGraphPbTxt(input: ArrayBuffer): + Promise { + return parsePbtxtFile(input, GRAPH_REPEATED_FIELDS).then(obj => obj['node']); +} + +/** + * Parses an ArrayBuffer of a proto txt file into a StepStats object. + */ +export function parseStatsPbTxt(input: ArrayBuffer): + Promise { + return parsePbtxtFile(input, METADATA_REPEATED_FIELDS) + .then(obj => obj['step_stats']); +} + +/** + * Parses a ArrayBuffer of a proto txt file into javascript object. + * + * @param input The ArrayBuffer or file object implementing slice. + * @param repeatedFields Map (Set) of all the repeated fields, since you can't + * tell directly from the pbtxt if a field is repeated or not. + * @returns The parsed object. + */ +function parsePbtxtFile( + input: ArrayBuffer, + repeatedFields: {[attrPath: string]: boolean}): Promise { + let output: { [name: string]: any; } = {}; + let stack = []; + let path: string[] = []; + let current: { [name: string]: any; } = output; + + function splitNameAndValueInAttribute(line: string) { + let colonIndex = line.indexOf(':'); + let name = line.substring(0, colonIndex).trim(); + let value = parseValue(line.substring(colonIndex + 2).trim()); + return { + name: name, + value: value + }; + } + + /** + * Adds a value, given the attribute name and the host object. If the + * attribute already exists, but is not an array, it will convert it to an + * array of values. + * + * @param obj The host object that holds the attribute. + * @param name The attribute name (key). + * @param value The attribute value. + * @param path A path that identifies the attribute. Used to check if + * an attribute is an array or not. + */ + function addAttribute(obj: Object, name: string, + value: Object|string|number|boolean, path: string[]): void { + // We treat 'node' specially since it is done so often. + let existingValue = obj[name]; + if (existingValue == null) { + obj[name] = path.join('.') in repeatedFields ? [value] : value; + } else if (Array.isArray(existingValue)) { + existingValue.push(value); + } else { + obj[name] = [existingValue, value]; + } + } + + // Run through the file a line at a time. + return streamParse(input, function(line: string) { + if (!line) { + return; + } + line = line.trim(); + + switch (line[line.length - 1]) { + case '{': // create new object + let name = line.substring(0, line.length - 2).trim(); + let newValue: { [name: string]: any; } = {}; + stack.push(current); + path.push(name); + addAttribute(current, name, newValue, path); + current = newValue; + break; + case '}': + current = stack.pop(); + path.pop(); + break; + default: + let x = splitNameAndValueInAttribute(line); + addAttribute(current, x.name, x.value, path.concat(x.name)); + break; + } + }).then(function() { + return output; + }); +} + +} // Close module tf.graph.parser. diff --git a/tensorflow/tensorboard/components/tf_graph_common_d3v4/proto.ts b/tensorflow/tensorboard/components/tf_graph_common_d3v4/proto.ts new file mode 100644 index 0000000000..eda73e45c3 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_graph_common_d3v4/proto.ts @@ -0,0 +1,143 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the 'License'); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an 'AS IS' BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +/** + * @fileoverview Interfaces that parallel proto definitions in + * third_party/tensorflow/core/framework/... + * graph.proto + * step_stats.proto + * These should stay in sync. + */ +module tf.graph.proto { + /** + * TensorFlow node definition as defined in the graph.proto file. + */ + export interface NodeDef { + /** Name of the node */ + name: string; + /** List of nodes that are inputs for this node. */ + input: string[]; + /** The name of the device where the computation will run. */ + device: string; + /** The name of the operation associated with this node. */ + op: string; + /** List of attributes that describe/modify the operation. */ + attr: {key: string, value: Object}[]; + } + + /** + * Generic graph as defined in the graph_explorer.proto file. + */ + export interface GenericGraph { + /** List of nodes in the graph */ + node: GenericNode[]; + /** List of nodes in the graph */ + edge: GenericEdge[]; + /** List of attributes that describe/modify the operation. */ + attr: Array<{[key: string]: any}>; + } + + /** + * GenericEdge corresponds to the Edge message in graph_explorer.proto. + */ + export interface GenericEdge { + /** Name of the source node. */ + source: string; + /** Name of the target node. */ + target: string; + /** Attributes of the edge. */ + edge_attr: Array<{[key: string]: any}>; + } + + /** + * GenericNode corresponds to the Node message in graph_explorer.proto. + */ + export interface GenericNode { + /** Name of the node */ + name: string; + /** Attributes of a leaf node or leaf nodes within a metanode. */ + node_attr: Array<{[key: string]: any}>; + /** Attributes of a metanode. */ + metanode_attr: Array<{[key: string]: any}>; + } + + /** + * TensorFlow stats file definition as defined in the stats proto file. + */ + export interface StepStats { + dev_stats: {device: string, node_stats: NodeExecStats[]}[]; + } + + /** + * TensorFlow stats for a node as defined in the step_stats proto file. + */ + export interface NodeExecStats { + node_name: string; + // The next 4 properties are currently stored as string in json + // and must be parsed. + all_start_micros: number; + op_start_rel_micros: number; + op_end_rel_micros: number; + all_end_rel_micros: number; + memory: { + allocator_name: string; + total_bytes: number; // Stored as string in json and should be parsed. + peak_bytes: number; // Stored as string in json and should be parsed. + }[]; + /** Output sizes recorded for a single execution of a graph node */ + output: NodeOutput[]; + timeline_label: string; + scheduled_micros: string; + thread_id: string; + } + + /** + * Description for the output tensor(s) of an operation in the graph as + * defined in the step_stats.proto file. + */ + export interface NodeOutput { + slot: number; // Stored as string in json and should be parsed. + tensor_description: { + /** Data type of tensor elements */ + dtype: string; + /** Shape of the tensor */ + shape: { + /** + * Dimensions of the tensor, such as [{name: 'input', size: 30}, + * {name: 'output', size: 40}] for a 30 x 40 2D tensor. The names + * are optional. The order of entries in 'dim' matters: It indicates + * the layout of the values in the tensor in-memory representation. + */ + dim: { + /** Size of the tensor in that dimension */ + size: number, // Stored as string in json and should be parsed. + /** Optional name of the tensor dimension */ + name?: string + }[]; + }; + /** Information about the size and allocator used for the data */ + allocation_description: { + // The next 2 properties are stored as string in json and + // should be parsed. + /** Total number of bytes requested */ + requested_bytes: number; + /** Total number of bytes allocated, if known */ + allocated_bytes?: number; + /** Name of the allocator used */ + allocator_name: string; + }; + }; + } +} diff --git a/tensorflow/tensorboard/components/tf_graph_common_d3v4/render.ts b/tensorflow/tensorboard/components/tf_graph_common_d3v4/render.ts new file mode 100644 index 0000000000..474e358ba9 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_graph_common_d3v4/render.ts @@ -0,0 +1,1633 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the 'License'); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an 'AS IS' BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +/** + * Package for the Render Hierarchy for TensorFlow graph. + */ +module tf.graph.render { + +export type Point = {x: number, y: number}; + +/** + * Color parameters for op nodes. + */ +export let OpNodeColors = {DEFAULT_FILL: 'white', DEFAULT_STROKE: '#b2b2b2'}; + +/** + * Color parameters for node encoding. + * @type {Object} + */ +export let MetanodeColors = { + /** + * Default fill and stroke to use when no other information is available. + */ + DEFAULT_FILL: '#d9d9d9', + DEFAULT_STROKE: '#a6a6a6', + SATURATION: 0.6, + LIGHTNESS: 0.85, + /** + * Neutral color to use when the node is expanded (used when coloring by + * compute time, memory and device). + */ + EXPANDED_COLOR: '#f0f0f0', + /** + * Standard hue values for node color palette. + */ + HUES: [220, 100, 180, 40, 20, 340, 260, 300, 140, 60], + STRUCTURE_PALETTE(id: number, lightened?: boolean) { + // The code below is a flexible way to computationally create a set + // of colors that go well together. + let hues = MetanodeColors.HUES; + let n = hues.length; + let hue = hues[id % n]; + let m = Math.sin(hue * Math.PI / 360); + let sat = lightened ? 30 : 90 - 60 * m; + let light = lightened ? 95 : 80; + return d3.hsl(hue, .01 * sat, .01 * light).toString(); + }, + DEVICE_PALETTE(index: number): string { + return MetanodeColors.STRUCTURE_PALETTE(index); + }, + XLA_CLUSTER_PALETTE(index: number): string { + return MetanodeColors.STRUCTURE_PALETTE(index); + }, + UNKNOWN: '#eee', + GRADIENT_OUTLINE: '#888' +}; + +/** + * Color parameters for op nodes. + */ +export let SeriesNodeColors = { + DEFAULT_FILL: 'white', + DEFAULT_STROKE: '#b2b2b2' +}; + +/** + * Parameters that affect how the graph is rendered on the screen. + */ +const PARAMS = { + /** + * Whether to extract high degree nodes from the core part of the graph. + */ + enableExtraction: true, + /** + * The minimum number of nodes for a graph to have in order for high in and + * out degree nodes to be extracted in auxiliary. The aim here is to prevent + * nodes from being extracted from small graphs. + */ + minNodeCountForExtraction: 15, + /** + * The minimum in or out degree a node must have in order to be possibly + * extracted. + */ + minDegreeForExtraction: 5, + /** + * Maximum number of control edges a node can have before they aren't + * displayed. + */ + maxControlDegree: 4, + /** + * Maximum in (for outbound bridge paths) or out (for inbound bridge paths) + * degree of a node allowed for a bridge path to be rendered to it from a + * subhierarchy of nodes. Having a max prevents having too many nodes emanate + * from a subhierarchy and crowding up. + */ + maxBridgePathDegree: 4, + /** + * Types patterns for predefined out-extract nodes, which are + * sink-like nodes that will be extracted from the main graph. + */ + outExtractTypes: [ + 'NoOp' // NoOps are sink-like used for managing control dependencies. + ], + + /** + * Types patterns for predefined in-extract nodes, which are + * source-like nodes that will be extracted from the main graph. + */ + inExtractTypes: [], + + /** + * When removing edges from a high degree node, remove all of its edges if + * detachAllEdgesForHighDegree is true. Otherwise remove all in-edges if + * the node has high in-degree, or all out-edges if the node has high + * out-degree. + */ + detachAllEdgesForHighDegree: true, + + /** + * After extracting high in/out degree nodes and predefined + * source-like/sink-like, extract isolated nodes to the side + * if this extractIsolatedNodesWithAnnotationsOnOneSide is true. + */ + extractIsolatedNodesWithAnnotationsOnOneSide: true, + + /** + * Whether to add bridge nodes and edges to the core when building the + * subhierarchy of an expanded metanode. See buildSubhierarchy(). + */ + enableBridgegraph: true, + + /** + * 2 colors, for the minimum and maximum value respectively, whenever we + * have a gradient scale. + */ + minMaxColors: ['#fff5f0', '#fb6a4a'], + + /** + * Maximum number of annotations to be displayed on a node before an + * ellipsis is used. + */ + maxAnnotations: 5 +}; + +/** + * Stores the rendering information, such as x and y coordinates, + * for each node in the graph. + */ +export class RenderGraphInfo { + hierarchy: hierarchy.Hierarchy; + private displayingStats: boolean; + private index: {[nodeName: string]: RenderNodeInfo}; + private renderedOpNames: string[]; + private deviceColorMap: d3.ScaleOrdinal; + private xlaClusterColorMap: d3.ScaleOrdinal; + private memoryUsageScale: d3.ScaleLinear; + private computeTimeScale: d3.ScaleLinear; + /** Scale for the thickness of edges when there is no shape information. */ + edgeWidthScale: + d3.ScaleLinear | d3.ScalePower; + // Since the rendering information for each node is constructed lazily, + // upon node's expansion by the user, we keep a map between the node's name + // and whether the rendering information was already constructed for that + // node. + private hasSubhierarchy: {[nodeName: string]: boolean}; + root: RenderGroupNodeInfo; + traceInputs: Boolean; + + constructor(hierarchy: hierarchy.Hierarchy, displayingStats: boolean) { + this.hierarchy = hierarchy; + this.displayingStats = displayingStats; + this.index = {}; + this.renderedOpNames = []; + + this.computeScales(); + // Maps node name to whether the rendering hierarchy was already + // constructed. + this.hasSubhierarchy = {}; + this.root = new RenderGroupNodeInfo(hierarchy.root); + this.index[hierarchy.root.name] = this.root; + this.renderedOpNames.push(hierarchy.root.name); + this.buildSubhierarchy(hierarchy.root.name); + this.root.expanded = true; + this.traceInputs = false; + } + + computeScales() { + this.deviceColorMap = d3.scaleOrdinal() + .domain(this.hierarchy.devices) + .range(_.map(d3.range(this.hierarchy.devices.length), + MetanodeColors.DEVICE_PALETTE)); + + this.xlaClusterColorMap = + d3.scaleOrdinal() + .domain(this.hierarchy.xlaClusters) + .range(_.map( + d3.range(this.hierarchy.xlaClusters.length), + MetanodeColors.XLA_CLUSTER_PALETTE)); + + let topLevelGraph = this.hierarchy.root.metagraph; + // Find the maximum and minimum memory usage. + let memoryExtent = d3.extent(topLevelGraph.nodes(), + (nodeName, index) => { + let node = topLevelGraph.node(nodeName); + // Some ops don't have stats at all. + if (node.stats != null) { + return node.stats.totalBytes; + } + }); + this.memoryUsageScale = d3.scaleLinear() + .domain(memoryExtent) + .range(PARAMS.minMaxColors); + + // Find also the minimum and maximum compute time. + let computeTimeExtent = d3.extent(topLevelGraph.nodes(), + (nodeName, index) => { + let node = topLevelGraph.node(nodeName); + // Some ops don't have stats at all. + if (node.stats != null) { + return node.stats.getTotalMicros(); + } + }); + this.computeTimeScale = d3.scaleLinear() + .domain(computeTimeExtent) + .range(PARAMS.minMaxColors); + + this.edgeWidthScale = this.hierarchy.hasShapeInfo ? + scene.edge.EDGE_WIDTH_SCALE : + d3.scaleLinear() + .domain([1, this.hierarchy.maxMetaEdgeSize]) + .range([scene.edge.MIN_EDGE_WIDTH, scene.edge.MAX_EDGE_WIDTH]); + } + + /** + * Get a previously created RenderNodeInfo by its node name. + */ + getRenderNodeByName(nodeName: string): RenderNodeInfo { + return this.index[nodeName]; + } + + /** + * Get the underlying node in the hierarchical graph by its name. + */ + getNodeByName(nodeName: string): Node { + return this.hierarchy.node(nodeName); + } + + /** + * Get a previously created RenderNodeInfo for the specified node name, + * or create one if it hasn't been created yet. + */ + getOrCreateRenderNodeByName(nodeName: string): RenderNodeInfo { + // Polymer may invoke this with null. + if (!nodeName) { + return null; + } + + if (nodeName in this.index) { + return this.index[nodeName]; + } + + let node = this.hierarchy.node(nodeName); + // Exit early if the node does not exist in the hierarchy. This can happen + // when a graph is reloaded while the infocard points to a node not visible + // at the top-level. + if (!node) { + return null; + } + let renderInfo = node.isGroupNode ? + new RenderGroupNodeInfo(node) : + new RenderNodeInfo(node); + this.index[nodeName] = renderInfo; + this.renderedOpNames.push(nodeName); + + if (node.stats) { + renderInfo.memoryColor = this.memoryUsageScale(node.stats.totalBytes); + renderInfo.computeTimeColor = + this.computeTimeScale(node.stats.getTotalMicros()); + } + + if (!node.isGroupNode) { + let clusterName = (node as OpNode).xlaCluster; + if (clusterName) { + renderInfo.xlaClusterColor = this.xlaClusterColorMap(clusterName); + } + } + + // We only fade nodes when we're displaying stats. + renderInfo.isFadedOut = this.displayingStats && + !tf.graph.util.hasDisplayableNodeStats(node.stats); + + if (node.isGroupNode) { + // Make a list of tuples (device, proportion), where proportion + // is the fraction of op nodes that have that device. + let pairs = _.pairs((node).deviceHistogram); + if (pairs.length > 0) { + // Compute the total # of devices. + let numDevices = _.sum(pairs, _.last); + renderInfo.deviceColors = _.map(pairs, pair => ({ + color: this.deviceColorMap(pair[0]), + // Normalize to a proportion of total # of devices. + proportion: pair[1] / numDevices + })); + } + } else { + let device = (renderInfo.node).device; + if (device) { + renderInfo.deviceColors = [{ + color: this.deviceColorMap(device), + proportion: 1.0 + }]; + } + } + + return this.index[nodeName]; + } + + /** + * Return the nearest ancestor node, including itself, that is visible + * in the visualization. This method is used so that we can select + * (highlight) a node that isn't drawn yet, by selecting (highlighting) + * its nearest ancestor that has been drawn. + */ + getNearestVisibleAncestor(name: string): string { + let path = getHierarchicalPath(name); + for (let i = 0; i < path.length; i++) { + let nodeName = path[i]; + // Op nodes have expanded set to false by default. + if (!this.getRenderNodeByName(nodeName).expanded) { + return nodeName; + } + } + // Fallthrough. If everything was expanded return the node. + return name; + } + + // TODO(jimbo): Delete this an any code it touches (all deprecated). + setDepth(depth: number): void { + setGroupNodeDepth(this.root, +depth); + } + + /** + * Returns true if the renderNode is an isolated node within its parent node. + */ + isNodeAuxiliary(renderNode: RenderNodeInfo): boolean { + let parentNode = this.getRenderNodeByName( + renderNode.node.parentNode.name); + let found = _.find(parentNode.isolatedInExtract, node => { + return node.node.name === renderNode.node.name; + }); + if (found) { + return true; + } + found = _.find(parentNode.isolatedOutExtract, node => { + return node.node.name === renderNode.node.name; + }); + return !!found; + } + + /** + * Returns a list of ops that have been rendered so far for this graph. More + * ops may later be rendered if the user expands nodes for instance. The list + * returned here can only stay the same size or grow on successive calls. + */ + getNamesOfRenderedOps(): string[] { + return this.renderedOpNames; + } + + buildSubhierarchy(nodeName: string): void { + // Terminate if the rendering hierarchy was already constructed + // for this node. + if (nodeName in this.hasSubhierarchy) { + return; + } + + let renderNodeInfo = this.index[nodeName]; + + // If it is not a meta node or a series node, don't do anything. + if (renderNodeInfo.node.type !== NodeType.META && + renderNodeInfo.node.type !== NodeType.SERIES) { + return; + } + + // At this point we know the rendering information is about a group node. + let renderGroupNodeInfo = renderNodeInfo; + let metagraph = renderGroupNodeInfo.node.metagraph; + let coreGraph = renderGroupNodeInfo.coreGraph; + + // Create render nodes to represent each child from the metagraph. Although + // these will initially be added to the coreGraph, they may later be + // extracted. Also, due to extraction, the coreGraph may contain disjoint + // groups between which there is no visible path (other than annotations). + _.each(metagraph.nodes(), childName => { + + let childRenderInfo = this.getOrCreateRenderNodeByName(childName); + let childNode = childRenderInfo.node; + + coreGraph.setNode(childName, childRenderInfo); + + if (!childNode.isGroupNode) { + _.each((childNode).inEmbeddings, embedding => { + let renderMetaedgeInfo = new RenderMetaedgeInfo(null); + addInAnnotation(childRenderInfo, embedding, null, renderMetaedgeInfo, + AnnotationType.CONSTANT); + this.index[embedding.name] = new RenderNodeInfo(embedding); + }); + _.each((childNode).outEmbeddings, embedding => { + let renderMetaedgeInfo = new RenderMetaedgeInfo(null); + addOutAnnotation(childRenderInfo, embedding, null, renderMetaedgeInfo, + AnnotationType.SUMMARY); + this.index[embedding.name] = new RenderNodeInfo(embedding); + }); + } + + }); + + // Add render metaedge info for edges in the metagraph. + _.each(metagraph.edges(), edgeObj => { + let metaedge = metagraph.edge(edgeObj); + let renderMetaedgeInfo = new RenderMetaedgeInfo(metaedge); + renderMetaedgeInfo.isFadedOut = + this.index[edgeObj.v].isFadedOut || this.index[edgeObj.w].isFadedOut; + coreGraph.setEdge(edgeObj.v, edgeObj.w, renderMetaedgeInfo); + }); + + if (PARAMS.enableExtraction && + renderGroupNodeInfo.node.type === NodeType.META) { + extractHighDegrees(renderGroupNodeInfo); + } + + // Record that we constructed the rendering hierarchy for this node, so we + // don't construct it another time. + this.hasSubhierarchy[nodeName] = true; + + // Look up the parent node's render information and short circuit if none. + let parentNode = renderGroupNodeInfo.node.parentNode; + if (!parentNode) { + return; + } + let parentNodeInfo = + this.index[parentNode.name]; + + // Utility function for computing the name of a bridge node. + let getBridgeNodeName = (inbound, ...rest) => + rest.concat([inbound ? 'IN' : 'OUT']).join('~~'); + + // Build out the bridgegraph. + let bridgegraph = this.hierarchy.getBridgegraph(nodeName); + + // Look for popular nodes so we can make annotations instead of paths. + let otherCounts = { + // Counts of edges coming INTO other nodes by name (outgoing from self). + in: <{[nodeName: string]: number}> {}, + // Counts of edges going OUT from other nodes by name (coming into self). + out: <{[nodeName: string]: number}> {}, + // Counts of all control edges involving other nodes by name. + control: <{[nodeName: string]: number}> {}, + }; + _.each(bridgegraph.edges(), e => { + // An edge is inbound if its destination node is in the metagraph. + let inbound = !!metagraph.node(e.w); + let otherName = inbound ? e.v : e.w; + let metaedge = bridgegraph.edge(e); + if (!metaedge.numRegularEdges) { + otherCounts.control[otherName] = + (otherCounts.control[otherName] || 0) + 1; + } else if (inbound) { + otherCounts.out[otherName] = (otherCounts.out[otherName] || 0) + 1; + } else { + otherCounts.in[otherName] = (otherCounts.in[otherName] || 0) + 1; + } + }); + + // Add annotations and edges for bridgegraph relationships. + let hierarchyNodeMap = this.hierarchy.getNodeMap(); + _.each(bridgegraph.edges(), bridgeEdgeObj => { + let bridgeMetaedge = bridgegraph.edge(bridgeEdgeObj); + + // Determine whether this bridge edge is incoming by checking the + // metagraph for a node that matches the destination end. + let inbound = !!metagraph.node(bridgeEdgeObj.w); + + // Based on the direction of the edge, one endpoint will be an immediate + // child of this renderNodeInfo, and the other endpoint will be a sibling + // of the parent (or an ancestor further up). + let [childName, otherName] = + inbound ? + [bridgeEdgeObj.w, bridgeEdgeObj.v] : + [bridgeEdgeObj.v, bridgeEdgeObj.w]; + + let childRenderInfo = this.index[childName]; + let otherRenderInfo = this.index[otherName]; + let otherNode = + otherRenderInfo ? + otherRenderInfo.node : + hierarchyNodeMap[otherName]; + + // Determine whether this edge is a control edge between nodes where + // either node is high-degree with respect to control edges. This will + // be a signal to show it as an annotation instead of a bridge edge. + let isHighDegreeControlEdge = !bridgeMetaedge.numRegularEdges && + otherCounts.control[otherName] > PARAMS.maxControlDegree; + + let [, childAnnotations] = + inbound ? + [renderNodeInfo.inAnnotations, childRenderInfo.inAnnotations] : + [renderNodeInfo.outAnnotations, childRenderInfo.outAnnotations]; + + // Don't render a bridge path if the other node has in or out degree above + // a threshold, lest bridge paths emanating out of a metagraph crowd up, + // as was the case for the Fatcat LSTM lstm_1 > lstm_1 metagraph. + let otherDegreeCount = + (inbound ? otherCounts.out : otherCounts.in)[otherName]; + let isOtherHighDegree = otherDegreeCount > PARAMS.maxBridgePathDegree; + + // The adjoining render metaedge info from the parent's coreGraph, if any. + // It will either be a Metaedge involving this node directly, if it + // previously came from a metagraph, or it'll be a Metaedge involving + // a previously created bridge node standing in for the other node. + let adjoiningMetaedge = null; + + // We can only hope to render a bridge path if: + // - bridgegraph paths are enabled, + // - the other node is not too high-degree, + // - the child is in the core (not extracted for being high-degree), and + // - there's a path (in the traversal sense) between child and other. + let canDrawBridgePath = false; + if (PARAMS.enableBridgegraph && + !isOtherHighDegree && + !isHighDegreeControlEdge && + childRenderInfo.isInCore()) { + + // Utility function for finding an adjoining metaedge. + let findAdjoiningMetaedge = targetName => { + let adjoiningEdgeObj: graphlib.EdgeObject = + inbound ? + { v: targetName, w: nodeName } : + { v: nodeName, w: targetName }; + return + parentNodeInfo.coreGraph.edge(adjoiningEdgeObj); + }; + + adjoiningMetaedge = findAdjoiningMetaedge(otherName); + if (!adjoiningMetaedge) { + adjoiningMetaedge = findAdjoiningMetaedge( + getBridgeNodeName(inbound, otherName, parentNode.name)); + } + + canDrawBridgePath = !!adjoiningMetaedge; + } + + // Although dataflow edges are acyclic, control dependency edges may + // actually point 'backwards' in the graph. If this bridgeMetaedge is + // a control dependency, we need to determine whether it's backwards + // pointing so that we render it appropriately. + // + // For instance, say we're rendering a graph with nodes named A/B and Z/Y, + // and we're currently rendering the bridgegraph for A. Further, let's say + // that there was an original BaseEdge from A/B->Z/Y and a CONTROL EDGE + // from Z/Y=>A/B. + // + // +----------------+ + // | A | + // | +-----+ | +------+ + // | | B |>----->|>------->| Z | + // | | | | | | + // | | | * | | | + // | | |<=====<|<=======<| | + // | +-----+ | +------+ + // +----------------+ + // + // When we render the subhierarchy for Metanode A, we'll come across a + // control-only Metaedge in the bridgegraph from Z=>A/B (*). The question + // is whether this edge is backwards. + // + // To answer that question, we follow the chain of adjoining metaedges + // until we reach the topmost one. In this case, that's the control-only + // Metaedge Z=>A in the ROOT's metagraph. We determine that this edge + // is backwards by looking at the topological ordering of ROOT's metagraph + // (which ignores control edges) and seeing that Z comes AFTER A. + // + // The property of being backwards is independent of whether the edge + // is inbound or outbound. In the preceding example, if we were building + // the subhierarchy for Z, we'd find bridge edge Z/Y=>A, walk to its + // topmost adjoining metaedge Z=>A and discover that it's backwards. + let backwards = false; + if (adjoiningMetaedge && !bridgeMetaedge.numRegularEdges) { + // Find the top-most adjoining render metaedge information, and the + // GroupNode whose metagraph must contain the associated metaedge. + let topAdjoiningMetaedge = adjoiningMetaedge; + let topGroupNode = parentNodeInfo.node; + while (topAdjoiningMetaedge.adjoiningMetaedge) { + topAdjoiningMetaedge = topAdjoiningMetaedge.adjoiningMetaedge; + topGroupNode = topGroupNode.parentNode; + } + + // Check against the topological ordering for the top node. The current + // bridge metaedge we're evaluating is backwards if its source comes + // after its destination. + let ordering = this.hierarchy.getTopologicalOrdering(topGroupNode.name); + let e = topAdjoiningMetaedge.metaedge; + backwards = ordering[e.v] > ordering[e.w]; + } + + // Render backwards control edges as annotations. + canDrawBridgePath = canDrawBridgePath && !backwards; + + // If we can't make a bridge path for any reason, then we add an + // annotation instead. + if (!canDrawBridgePath) { + childAnnotations.push(new Annotation( + otherNode, + otherRenderInfo, + new RenderMetaedgeInfo(bridgeMetaedge), + AnnotationType.SHORTCUT, + inbound)); + return; + } + + // At this point, all conditions have been met for drawing a bridge path. + + // Find or create the IN/OUT node representing otherNode. + let bridgeContainerName = getBridgeNodeName(inbound, nodeName); + let bridgeNodeName = getBridgeNodeName(inbound, otherName, nodeName); + let bridgeNodeRenderInfo = coreGraph.node(bridgeNodeName); + if (!bridgeNodeRenderInfo) { + + // Find or create the directional container for the bridge node. + let bridgeContainerInfo = coreGraph.node(bridgeContainerName); + if (!bridgeContainerInfo) { + let bridgeContainerNode: BridgeNode = { + // Important node properties. + name: bridgeContainerName, + type: NodeType.BRIDGE, + // Unused node properties. + isGroupNode: false, + cardinality: 0, + parentNode: null, + stats: null, + include: InclusionType.UNSPECIFIED, + // BridgeNode properties. + inbound: inbound, + nodeAttributes: {}, + }; + bridgeContainerInfo = + new RenderNodeInfo(bridgeContainerNode); + this.index[bridgeContainerName] = bridgeContainerInfo; + coreGraph.setNode(bridgeContainerName, bridgeContainerInfo); + } + + let bridgeNode: BridgeNode = { + // Important node properties. + name: bridgeNodeName, + type: NodeType.BRIDGE, + // Unimportant node properties. + isGroupNode: false, + cardinality: 1, + parentNode: null, + stats: null, + include: InclusionType.UNSPECIFIED, + // BridgeNode properties. + inbound: inbound, + nodeAttributes: {}, + }; + bridgeNodeRenderInfo = new RenderNodeInfo(bridgeNode); + this.index[bridgeNodeName] = bridgeNodeRenderInfo; + coreGraph.setNode(bridgeNodeName, bridgeNodeRenderInfo); + + // Set bridgeNode to be a graphlib child of the container node. + coreGraph.setParent(bridgeNodeName, bridgeContainerName); + bridgeContainerInfo.node.cardinality++; + } + + // Create and add a bridge render metaedge. + let bridgeRenderMetaedge = + new RenderMetaedgeInfo(bridgeMetaedge); + bridgeRenderMetaedge.adjoiningMetaedge = adjoiningMetaedge; + inbound ? + coreGraph.setEdge(bridgeNodeName, childName, bridgeRenderMetaedge) : + coreGraph.setEdge(childName, bridgeNodeName, bridgeRenderMetaedge); + + }); // End _.each(bridgegraph.edges). + + // For each bridge container (IN and/or OUT), add structural edges between + // terminal nodes and that container. A terminal node is one which has no + // non-bridge edges in the direction of the container. + // + // For example, consider a Metanode A which contains two child nodes A/B + // and A/C. Let's say it has one edge in the metagraph from A/B->A/C, and + // one edge in the bridgegraph from Z->A/C. + // + // At this point, we've added a container bridge node IN to house all + // incoming bridge nodes. We've also added a bridge node Z' (with parent IN) + // to A, and a bridge edge from Z'->C. + // + // +----------------------+ + // | A +---+ | + // | +------>| C | | + // | | +---+ | + // | | ^ | + // | | | | + // | | +----|----+ | + // | | | IN | | | + // | +---+ | +---+ | | + // | | B | | | Z'| | | + // | +---+ | +---+ | | + // | +---------+ | + // +----------------------+ + // + // With no other help, dagre would lay out B and Z' on the same level, + // because both of them have no incoming edges. In other words, B is a + // terminal node in the INCOMING direction. + // + // But we want to force dagre to lay out Z' (and everything in IN) lower + // than all non-bridge nodes, so that there's enough room for the bridge + // edges after they've been adjusted to meet up with paths coming in from + // outside. + // + // To force Z' (and all other bridge nodes) to be lowest in the graph, we + // identify terminal nodes like B and give them structural edges to + // a new structural bridge node S which we add to IN. + // + // +----------------------+ + // | A +---+ | + // | +--->| C | | + // | | +---+ | + // | +---+ ^ | + // | | B | | | + // | +---+ | | + // | ^ | | + // | | | | + // | +----|------|----+ | + // | |IN | | | | + // | | +---+ +---+ | | + // | | | S | | Z'| | | + // | | +---+ +---+ | | + // | +----------------+ | + // +----------------------+ + // + // This ensures that dagre will lay out the bridge containers strictly at + // the ends of the graph. The structural edges will never be seen in the + // visualization except as a debugging aid. + _.each([true, false], inbound => { + let bridgeContainerName = getBridgeNodeName(inbound, nodeName); + let bridgeContainerInfo = coreGraph.node(bridgeContainerName); + if (!bridgeContainerInfo) { + return; + } + _.each(coreGraph.nodes(), childName => { + // Short-circuit if this child is a bridge node or it's not a terminal + // node in the direction we're interested in. + let childNodeInfo = coreGraph.node(childName); + if (childNodeInfo.node.type === NodeType.BRIDGE) { + return; + } + let isTerminal = inbound ? + !coreGraph.predecessors(childName).length : + !coreGraph.successors(childName).length; + if (!isTerminal) { + return; + } + + // Find or create a bridge node in the container for all structural + // metaedges. It would have been nice to skip this step and simply + // set a metaedge between the terminal node and the container node, but + // in that case, something about the graph upsets dagre.layout()'s + // longestPath algorithm (was getting errors due to an undefined). + let structuralNodeName = + getBridgeNodeName(inbound, nodeName, 'STRUCTURAL_TARGET'); + let structuralRenderInfo = coreGraph.node(structuralNodeName); + if (!structuralRenderInfo) { + let bridgeNode: BridgeNode = { + // Important Node properties. + name: structuralNodeName, + type: NodeType.BRIDGE, + // Unimportant Node properties. + isGroupNode: false, + cardinality: 1, + parentNode: null, + stats: null, + include: InclusionType.UNSPECIFIED, + // BridgeNode properties. + inbound: inbound, + nodeAttributes: {}, + }; + structuralRenderInfo = new RenderNodeInfo(bridgeNode); + structuralRenderInfo.structural = true; + this.index[structuralNodeName] = structuralRenderInfo; + coreGraph.setNode(structuralNodeName, structuralRenderInfo); + bridgeContainerInfo.node.cardinality++; + coreGraph.setParent(structuralNodeName, bridgeContainerName); + } + + // Create the structural Metaedge and insert it. + let structuralMetaedgeInfo = new RenderMetaedgeInfo(null); + structuralMetaedgeInfo.structural = true; + structuralMetaedgeInfo.weight--; // Reduce weight for dagre layout. + inbound ? + coreGraph.setEdge( + structuralNodeName, childName, structuralMetaedgeInfo) : + coreGraph.setEdge( + childName, structuralNodeName, structuralMetaedgeInfo); + }); + }); + } +} + +/** + * A class for rendering annotation object which contains label + * about the node embedded as annotation, type of annotation and the location + * of both the annotation's node and edge. + * + * Annotation objects include embedded constants, embedded summary, and + * edge shortcuts. + */ +export class Annotation { + node: Node; + renderNodeInfo: RenderNodeInfo; + renderMetaedgeInfo: RenderMetaedgeInfo; + annotationType: AnnotationType; + /** + * Center position of annotation relative to the host + * node's center x. + */ + dx: number; + /** + * Center position of annotation relative to the host + * node's center y. + */ + dy: number; + width: number; + height: number; + /** + * The names of nodes on either side of this edge. + */ + v: string; + w: string; + /** + * A flag whether it is an in-annotation (if true) or + * out-annotation (if false). + */ + isIn: boolean; + /** Label horizontal offset from the end of the node shape */ + labelOffset: number; + /** + * Array of points for edges from the annotation to its host + * node. Each point contains the point location, relative to + * the host node's center. + */ + points: {dx: number, dy: number}[]; + + /** + * Creates a new Annotation. + * + * @param node The underlying node this annotation points to. + * @param renderNodeInfo The render information for the underlying node + * this annotation points to. This can be null if the annotation + * denotes an embedding (constant, summary), in which case we + * use the node property. + * @param renderMetaedgeInfo The render information for the edge associated + * with the annotation. + * @param type The type of the annotation. + * @param isIn True if it is an in-annotation. False if it is an + * out-annotation. + */ + constructor(node: Node, renderNodeInfo: RenderNodeInfo, + renderMetaedgeInfo: RenderMetaedgeInfo, type: AnnotationType, + isIn: boolean) { + this.node = node; + this.renderNodeInfo = renderNodeInfo; + this.renderMetaedgeInfo = renderMetaedgeInfo; + this.annotationType = type; + // Properties specified by layout + this.dx = 0; + this.dy = 0; + this.width = 0; + this.height = 0; + // Properties needed for generating an ID for the edge's path element if + // this annotation is associated with a metaedge. + if (renderMetaedgeInfo && renderMetaedgeInfo.metaedge) { + this.v = renderMetaedgeInfo.metaedge.v; + this.w = renderMetaedgeInfo.metaedge.w; + } + + this.isIn = isIn; + this.points = []; + } +}; + +export enum AnnotationType {SHORTCUT, CONSTANT, SUMMARY, ELLIPSIS}; + +/** + * Manages a list of annotations. Two will be used for each + * RenderNodeInfo, one for in annotations and one for out annotations. + */ +export class AnnotationList { + /** + * List of visually drawable annotations, may include an ellipses annotation + * if the number added exceeds the number specified by maxAnnotations. + */ + list: Annotation[]; + + /** + * Set of nodes which have been added as annotations to this list, so we can + * prevent duplicates. + */ + nodeNames: { [nodeName: string]: boolean }; + + constructor() { + this.list = []; + this.nodeNames = {}; + } + + /** + * Append an annotation to the list, or a stand-in ellipsis annotation instead + * if this would make it too many. + */ + push(annotation: Annotation): void { + if (annotation.node.name in this.nodeNames) { + return; // Skip duplicate annotation. + } + this.nodeNames[annotation.node.name] = true; + + if (this.list.length < PARAMS.maxAnnotations) { + this.list.push(annotation); + return; + } + + let lastAnnotation = this.list[this.list.length - 1]; + if (lastAnnotation.annotationType === AnnotationType.ELLIPSIS) { + let ellipsisNode = lastAnnotation.node; + ellipsisNode.setNumMoreNodes(++ellipsisNode.numMoreNodes); + return; + } + + let ellipsisNode = new tf.graph.EllipsisNodeImpl(1); + this.list.push(new Annotation(ellipsisNode, + new RenderNodeInfo(ellipsisNode), null, + AnnotationType.ELLIPSIS, annotation.isIn)); + } +} + +/** + * Contains rendering information about a node in the hierarchical graph. + */ +export class RenderNodeInfo { + /** Reference to the original underlying Node from the hierarchical graph. */ + node: Node; + /** Whether the node is expanded or not. */ + expanded: boolean; + /** + * List of rendering information about in-annotations like constants and + * shortcuts to high-degree nodes. + */ + inAnnotations: AnnotationList; + /** + * List of rendering information about out-annotations (e.g. summary nodes) + */ + outAnnotations: AnnotationList; + + // --- Params specified by layout --- // + + /** Center x position */ + x: number; + /** Center y position */ + y: number; + /** + * Total width of the node's shape, including in- and out-annotations. This + * property is used by dagre to layout the graph. + */ + width: number; + /** + * Total height of the node's shape, including in- and out-annotations. This + * property is used by dagre to layout the graph. + */ + height: number; + /** + * Size of the main box of the node, excluding in- and out-annotations. This + * property is used to draw the rectangle/ellipse shape denoting the node. + */ + coreBox: { + width: number, + height: number, + }; + + /** Width of the bounding box for all in-annotations. */ + inboxWidth: number; + /** Width of the bounding box for all out-annotations. */ + outboxWidth: number; + /** + * Whether the node should be excluded from the scene. + * This is only used when there are too many items in a series so we only + * want to include top N ones. + */ + // TODO(jimbo): Now that series rendering is non-recursive, remove this and + // all its uses from the code base. + excluded: boolean; + + // --- Params used in drawing the bridge paths --- // + + /** + * All bridge nodes are meant to be invisible, but whereas most represent a + * relationship from the underlying graph hierarchy, some exist solely for + * layout reasons. Specifically, those bridge nodes which have only structural + * rendering metaedges. + */ + structural: boolean; + + // --- Params for the size of the node box --- // + + /** Label vertical offset from the center of node shape */ + labelOffset: number; + /** Rectangle radius (for making rounded rectangle) */ + radius: number; + + // --- Params for expanded node --- // + + /** Label height for expanded node. */ + labelHeight: number; + // Paddings between inner subscene and the border of the expanded node. + paddingTop: number; + paddingLeft: number; + paddingRight: number; + paddingBottom: number; + + /** + * Whether a node is extracted as source-like (having high out-degree or + * matching predefined in-extract pattern.) + */ + isInExtract: boolean; + /** + * Whether a node is extracted as sink-like (having high in-degree or matching + * predefined out-extract pattern.) + */ + isOutExtract: boolean; + + /** + * List of (color, proportion) tuples based on the proportion of devices of + * its children. If this node is an op node, this list will have only one + * color with proportion 1.0. + */ + deviceColors: Array<{color: string, proportion: number}>; + + /** + * Color according to the XLA cluster of this node. + */ + xlaClusterColor: string; + + /** + * Color according to the memory usage of this node. + */ + memoryColor: string; + + /** + * Color according to the compute time of this node. + */ + computeTimeColor: string; + + /** + * Whether this node is faded out. Used when displaying stats. + */ + isFadedOut: boolean; + + constructor(node: Node) { + this.node = node; + this.expanded = false; + this.inAnnotations = new AnnotationList(); + this.outAnnotations = new AnnotationList(); + // Params specified by layout + this.x = 0; + this.y = 0; + this.width = 0; + this.height = 0; + this.inboxWidth = 0; + this.outboxWidth = 0; + + this.excluded = false; + + // Params for bridge paths. + this.structural = false; + + // Params for node box. + this.labelOffset = 0; + this.radius = 0; + + // Params for expanded node + this.labelHeight = 0; + this.paddingTop = 0; + this.paddingLeft = 0; + this.paddingRight = 0; + this.paddingBottom = 0; + this.isInExtract = false; + this.isOutExtract = false; + this.coreBox = {width: 0, height: 0}; + + // By default, we don't fade nodes out. Default to false for safety. + this.isFadedOut = false; + } + + isInCore(): boolean { + return !this.isInExtract && !this.isOutExtract; + } +} + +/** + * Contains rendering information about a Metaedge from the underlying + * hierarchical graph. It may be from either a metagraph or a bridgegraph. + */ +export class RenderMetaedgeInfo { + /** + * Reference to the original underlying Metaedge from the hierarchical graph, + * if any. This will be null for the edges which connect OpNodes to their + * embeddings, for example. + */ + metaedge: Metaedge; + + /** + * Reference to the adjoining RenderMetaedgeInfo from the parent's + * coreGraph. This is used during layout to determine the point at which this + * edge should touch the node's bounding box. This property will be null for + * edges which terminate at a node on both ends (all non-bridge edges). + */ + adjoiningMetaedge: RenderMetaedgeInfo; + + /** + * Most of the time, a RenderMetaedgeInfo object represents a real + * edge between nodes in the underlying graph structure. But sometimes, an + * edge only exists for layout purposes. These structural edges are added + * during buildSubhierarchy() to force dagre.layout() to put bridge nodes + * at the ends of the flow. + * @see buildSubhierarchy() + */ + structural: boolean; + + /** + * Weight of the edge, used by dagre when deciding how important an edge is. + * Edges with higher weight are made shorter and straighter. The default + * dagre uses is 1. + */ + weight: number; + + /** + * X and Y coordinate pairs of the points in the path of the edge. + * @see tf.graph.node.subsceneAdjustPaths + */ + points: Point[]; + + /** + * D3 selection of the group containing the path that displays this edge. + */ + edgeGroup: d3.Selection; + + /** Id of the used as a start-marker for the edge path. */ + startMarkerId: string; + + /** Id of the used as an end-marker for the edge path. */ + endMarkerId: string; + + /** + * Whether this edge is faded out. Used for fading out unused edges when + * displaying run statistics. + */ + isFadedOut: boolean; + + constructor(metaedge: Metaedge) { + this.metaedge = metaedge; + this.adjoiningMetaedge = null; + this.structural = false; + this.weight = 1; + this.isFadedOut = false; + } +} + +function addInAnnotation(node: RenderNodeInfo, predecessor: Node, + predecessorRenderInfo: RenderNodeInfo, + edge: RenderMetaedgeInfo, type: AnnotationType): void { + let annotation = new Annotation(predecessor, predecessorRenderInfo, edge, + type, true); + node.inAnnotations.push(annotation); +} + +function addOutAnnotation(node: RenderNodeInfo, successor: Node, + successorRenderInfo: RenderNodeInfo, edge: RenderMetaedgeInfo, + type: AnnotationType): void { + let annotation = new Annotation(successor, successorRenderInfo, edge, + type, false); + node.outAnnotations.push(annotation); +} + +function setGraphDepth(graph: graphlib.Graph, + depth: number) { + _.each(graph.nodes(), nodeName => { + let child = graph.node(nodeName); + child.expanded = depth > 1; // set all child of depth 1 to collapsed + if (depth > 0) { + switch (child.node.type) { + case NodeType.META: + case NodeType.SERIES: + setGroupNodeDepth(child, depth - 1); + break; + // Do nothing for leaf + } + } + }); +}; + +export class RenderGroupNodeInfo extends RenderNodeInfo { + node: GroupNode; + /** + * The core graph is derived from the underlying node's metagraph, minus + * the extracted source-like and sink-like nodes. + */ + coreGraph: graphlib.Graph; + /** Size of the bounding box for a metanode's isolated in-extract children. */ + inExtractBox: {width: number, height: number}; + /** + * Size of the bounding box for a metanode's isolated out-extract children. + */ + outExtractBox: {width: number, height: number}; + /** Array of isolated in-extract nodes. */ + isolatedInExtract: RenderNodeInfo[]; + /** Array of isolated out-extract nodes. */ + isolatedOutExtract: RenderNodeInfo[]; + + constructor(groupNode: GroupNode) { + super(groupNode); + let metagraph = groupNode.metagraph; + let gl = metagraph.graph(); + this.coreGraph = + createGraph( + gl.name, GraphType.CORE, { compound: true }); + this.inExtractBox = {width: 0, height: 0}; + this.outExtractBox = {width: 0, height: 0}; + this.isolatedInExtract = []; + this.isolatedOutExtract = []; + } +} + +function setGroupNodeDepth(renderInfo: RenderGroupNodeInfo, + depth: number): void { + if (renderInfo.coreGraph) { + setGraphDepth(renderInfo.coreGraph, depth); + } +} + +/** + * Remove an edge from the graph and add annotations to both ends of the edge. + * + * @param The core graph. + * @param v Source name. + * @param w Sink name. + */ +function createShortcut( + graph: graphlib.Graph, + v: string, w: string) { + let src = graph.node(v); + let sink = graph.node(w); + let edge = graph.edge(v, w); + + // If either of the nodes is explicitly included in the main graph and + // both nodes are in the main graph then do not create the shortcut + // and instead keep the real edge. + if ((src.node.include === InclusionType.INCLUDE || + sink.node.include === InclusionType.INCLUDE) && + src.node.include !== InclusionType.EXCLUDE && + sink.node.include !== InclusionType.EXCLUDE) { + return; + } + + // Add each annotation. + addOutAnnotation(src, sink.node, sink, edge, AnnotationType.SHORTCUT); + addInAnnotation(sink, src.node, src, edge, AnnotationType.SHORTCUT); + + // Remove the edge from the core graph. + graph.removeEdge(v, w); +} + +/** + * Remove edges from a node, and set its isOutExtract property to true, + * and remove the node and move it to isolatedOutExtract. + * + * If detachAllEdgesForHighDegree or forceDetach is true, extract all of its + * edges. Otherwise, only extract all in-edges. + */ +function makeOutExtract(renderNode: RenderGroupNodeInfo, n: string, + forceDetach?: boolean) { + let graph = renderNode.coreGraph; + let child = graph.node(n); + child.isOutExtract = true; + + _.each(graph.predecessors(n), (p, index) => { + createShortcut(graph, p, n); + }); + + if (PARAMS.detachAllEdgesForHighDegree || forceDetach) { + _.each(graph.successors(n), (s, index) => { + createShortcut(graph, n, s); + }); + } + + // Remove the node from the core graph if it no longer has neighbors. + if (graph.neighbors(n).length === 0) { + child.node.include = InclusionType.EXCLUDE; + renderNode.isolatedOutExtract.push(child); + graph.removeNode(n); + } +} + +/** + * Remove edges from a node, set its isInExtract property to true, + * and remove the node and move it to isolatedInExtract. + * + * If detachAllEdgesForHighDegree or forceDetach is true, extract all of its + * edges. Otherwise, only remove all out-edges. + */ +export function makeInExtract(renderNode: RenderGroupNodeInfo, n: string, + forceDetach?: boolean) { + let graph = renderNode.coreGraph; + let child = graph.node(n); + child.isInExtract = true; + + _.each(graph.successors(n), (s, index) => { + createShortcut(graph, n, s); + }); + + if (PARAMS.detachAllEdgesForHighDegree || forceDetach) { + _.each(graph.predecessors(n), (p, index) => { + createShortcut(graph, p, n); + }); + } + + // Remove the node from the core graph if it no longer has neighbors. + if (graph.neighbors(n).length === 0) { + child.node.include = InclusionType.EXCLUDE; + renderNode.isolatedInExtract.push(child); + graph.removeNode(n); + } +} + +/** + * Check whether the node's type is a member of the given list of types. + * + * @param node Node. + * @param types List of type to match. + */ +function hasTypeIn(node: Node, types: string[]): boolean { + if (node.type === NodeType.OP) { + for (let i = 0; i < types.length; i++) { + if ((node).op === types[i]) { return true; } + } + } else if (node.type === NodeType.META) { + let rootOpNode = (node).getRootOp(); + if (rootOpNode) { + for (let i = 0; i < types.length; i++) { + if (rootOpNode.op === types[i]) { return true; } + } + } + } + return false; +} + +/** Move nodes that are specified to be excluded out of the core graph. */ +function extractSpecifiedNodes(renderNode: RenderGroupNodeInfo) { + let graph = renderNode.coreGraph; + _.each(graph.nodes(), n => { + let renderInfo = graph.node(n); + if (renderInfo.node.include === InclusionType.EXCLUDE) { + if (renderNode.coreGraph.outEdges(n).length > + renderNode.coreGraph.inEdges(n).length) { + makeOutExtract(renderNode, n, true); + } else { + makeInExtract(renderNode, n, true); + } + } + }); +} + +/** Remove edges from pre-defined out-extract patterns */ +function extractPredefinedSink(renderNode: RenderGroupNodeInfo) { + let graph = renderNode.coreGraph; + _.each(graph.nodes(), n => { + let renderInfo = graph.node(n); + if (renderInfo.node.include !== InclusionType.UNSPECIFIED) { + return; + } + if (hasTypeIn(renderInfo.node, PARAMS.outExtractTypes)) { + makeOutExtract(renderNode, n); + } + }); +} + +/** Remove edges from pre-defined in-extract patterns */ +function extractPredefinedSource(renderNode) { + let graph = renderNode.coreGraph; + _.each(graph.nodes(), n => { + let renderInfo = graph.node(n); + if (renderInfo.node.include !== InclusionType.UNSPECIFIED) { + return; + } + if (hasTypeIn(renderInfo.node, PARAMS.inExtractTypes)) { + makeInExtract(renderNode, n); + } + }); +} + +/** Extract nodes deemed to have either high in-degree or high out-degree. */ +function extractHighInOrOutDegree(renderNode: RenderGroupNodeInfo) { + let graph = renderNode.coreGraph; + + // Create mappings from node to in and out degrees. Count the number of valid + // nodes along the way. + let nodeToInDegree = {}; + let nodeToOutDegree = {}; + let validNodeCount = 0; + _.each(graph.nodes(), currentNode => { + if (graph.node(currentNode).node.include !== InclusionType.UNSPECIFIED) { + // This node is not included in the first place. + return; + } + + // Count the in and out degrees based on only regular edges, unless there + // are no regular edges, in which case use the number of control edges. + // This is done so that control edges don't affect if nodes are extracted + // from the core graph, unless the node is only used for control. + let inDegree = + _.reduce(graph.predecessors(currentNode), (inDegree, pred) => { + let metaedge = graph.edge(pred, currentNode).metaedge; + return inDegree + (metaedge.numRegularEdges ? 1 : 0); + }, 0); + if (inDegree === 0 && graph.predecessors(currentNode).length > 0) { + inDegree = graph.predecessors(currentNode).length; + } + + let outDegree = + _.reduce(graph.successors(currentNode), (outDegree, succ) => { + let metaedge = graph.edge(currentNode, succ).metaedge; + return outDegree + (metaedge.numRegularEdges ? 1 : 0); + }, 0); + if (outDegree === 0 && graph.successors(currentNode).length > 0) { + outDegree = graph.successors(currentNode).length; + } + + // Store the in and out degrees of this node to avoid recomputing. + nodeToInDegree[currentNode] = inDegree; + nodeToOutDegree[currentNode] = outDegree; + validNodeCount++; + }); + + if (validNodeCount < PARAMS.minNodeCountForExtraction) { + // This graph has few nodes. Do not extract any nodes. + return; + } + + // We only extract if the node has a min in or out degree greater than this. + let minUpperBound = PARAMS.minDegreeForExtraction - 1; + + // Mark for extraction nodes with in-degree > Q3 + (Q3 - Q1). + let q3Index = Math.round(validNodeCount * 0.75); + let q1Index = Math.round(validNodeCount * 0.25); + let sortedByInDegree = Object.keys(nodeToInDegree).sort((node0, node1) => { + return nodeToInDegree[node0] - nodeToInDegree[node1]; + }); + let inDegreeQ3 = nodeToInDegree[sortedByInDegree[q3Index]]; + let inDegreeQ1 = nodeToInDegree[sortedByInDegree[q1Index]]; + let inDegreeUpperBound = inDegreeQ3 + inDegreeQ3 - inDegreeQ1; + // Only extract if the upper bound is high enough. + inDegreeUpperBound = Math.max(inDegreeUpperBound, minUpperBound); + for (let i = validNodeCount - 1; + nodeToInDegree[sortedByInDegree[i]] > inDegreeUpperBound; i--) { + // Extract a high in-degree node. + makeInExtract(renderNode, sortedByInDegree[i]); + } + + // Mark for extraction nodes with out-degree > Q3 + (Q3 - Q1) * 4. + let sortedByOutDegree = Object.keys(nodeToOutDegree).sort((node0, node1) => { + return nodeToOutDegree[node0] - nodeToOutDegree[node1]; + }); + let outDegreeQ3 = nodeToOutDegree[sortedByOutDegree[q3Index]]; + let outDegreeQ1 = nodeToOutDegree[sortedByOutDegree[q1Index]]; + // The upper bound for extracting out-degree nodes is higher than that for + // extracting in-degree ones (Note the "* 4") because, in practice, some + // graphs look worse with a smaller out-degree bound. For instance, a smaller + // out-degree bound removes the convolution nodes from cifar 10 train's graph. + let outDegreeUpperBound = outDegreeQ3 + (outDegreeQ3 - outDegreeQ1) * 4; + // Only extract if the upper bound is high enough. + outDegreeUpperBound = Math.max(outDegreeUpperBound, minUpperBound); + for (let i = validNodeCount - 1; + nodeToOutDegree[sortedByOutDegree[i]] > outDegreeUpperBound; i--) { + let node = graph.node(sortedByOutDegree[i]); + if (!node || node.isInExtract) { + // This node has already been extracted due to high in-degree. It might + // have been removed from the graph in general (during in-degree + // extraction) due to a lack of neighbors. Do not extract this node twice. + continue; + } + + // Extract a high out-degree node that has not already been extracted. + makeOutExtract(renderNode, sortedByOutDegree[i]); + } +} + +/** Remove control edges from nodes that have too many control edges */ +function removeControlEdges(renderNode: RenderGroupNodeInfo) { + let graph = renderNode.coreGraph; + + // Collect control edges into a map by node name. + let map = <{[nodeName: string]: graphlib.EdgeObject[]}>{}; + _.each(graph.edges(), e => { + if (!graph.edge(e).metaedge.numRegularEdges) { + (map[e.v] = map[e.v] || []).push(e); + (map[e.w] = map[e.w] || []).push(e); + } + }); + + // For each node with too many control edges, turn them into annotations. + _.each(map, (edges, nodeName) => { + if (edges.length > PARAMS.maxControlDegree) { + _.each(edges, e => createShortcut(graph, e.v, e.w)); + } + }); +} + +/** + * Given an integer, picks a hue that is far apart from other colors. + * The formula for picking color that avoid collision is: + * hue = (color range * golden ratio * index) % color range + */ +export function mapIndexToHue(id: number): number { + let GOLDEN_RATIO = 1.61803398875; + // Hue of 0 is reserved for the gray nodes. + let MIN_HUE = 1; + let MAX_HUE = 359; + let COLOR_RANGE = MAX_HUE - MIN_HUE; + return MIN_HUE + ((COLOR_RANGE * GOLDEN_RATIO * id) % COLOR_RANGE); +}; + +/** + * Remove edges and add to annotation instead. + * + * For root node, consider predefined types for source and sink. + * We do not extract predefined type from non-root so that Variables and the + * sgd node (op type = 'NoOp') do not get extract from inside own group. + * + * The order of extraction is important here as swapping the order can totally + * screw up the graph layout. + * + * @param {Render.Node} renderNode Node to manipulate. + */ +function extractHighDegrees(renderNode: RenderGroupNodeInfo) { + + extractSpecifiedNodes(renderNode); + + if (PARAMS.outExtractTypes) { + extractPredefinedSink(renderNode); + } + + // This has to come before extract high in-degree to protect the core part + // that takes many variables. + if (PARAMS.inExtractTypes) { + extractPredefinedSource(renderNode); + } + + extractHighInOrOutDegree(renderNode); + + if (PARAMS.maxControlDegree) { + removeControlEdges(renderNode); + } + + // Extract isolated nodes, which can be + // (1) source-like and sink-like nodes that are not originally isolated but + // become isolated after further removal. + // (2) isolated nodes with annotations on one-side. These might be either + // - nodes that originally have high out-degree but because we remove + // high in-degree nodes first, they no longer have high in-degree when + // we check. (Detecting all high-degree before removing also leads to + // another problem.) + // - nodes that do not have high degree, but their neighbors are all + // extracted, so it might make sense to extract them too. + + let graph = renderNode.coreGraph; + _.each(graph.nodes(), n => { + let child = graph.node(n); + let degree = graph.neighbors(n).length; + if (child.node.include !== InclusionType.UNSPECIFIED) { + return; + } + if (degree === 0) { + let hasOutAnnotations = child.outAnnotations.list.length > 0; + let hasInAnnotations = child.inAnnotations.list.length > 0; + + if (child.isInExtract) { // Is source-like. + // This case only happens if detachAllEdgesForHighDegree is false. + // (Otherwise all source-like nodes are all isolated already.) + renderNode.isolatedInExtract.push(child); + child.node.include = InclusionType.EXCLUDE; + graph.removeNode(n); + } else if (child.isOutExtract) { // Is sink-like. + // This case only happens if detachAllEdgesForHighDegree is false. + // // (Otherwise all sink-like nodes are all isolated already.) + renderNode.isolatedOutExtract.push(child); + child.node.include = InclusionType.EXCLUDE; + graph.removeNode(n); + } else if (PARAMS.extractIsolatedNodesWithAnnotationsOnOneSide) { + if (hasOutAnnotations && !hasInAnnotations) { + child.isInExtract = true; // for ones with high out-annotations + renderNode.isolatedInExtract.push(child); + child.node.include = InclusionType.EXCLUDE; + graph.removeNode(n); + } else if (hasInAnnotations && !hasOutAnnotations) { + child.isOutExtract = true; // for ones with high in-annotations + renderNode.isolatedOutExtract.push(child); + child.node.include = InclusionType.EXCLUDE; + graph.removeNode(n); + } else { + // if a low degree node has both in- & out- annotations, do nothing + // because it is unclear which side it should go to. + } + } + } + }); +} +} // close module tf.graph.render diff --git a/tensorflow/tensorboard/components/tf_graph_common_d3v4/scene.ts b/tensorflow/tensorboard/components/tf_graph_common_d3v4/scene.ts new file mode 100644 index 0000000000..06f03e910a --- /dev/null +++ b/tensorflow/tensorboard/components/tf_graph_common_d3v4/scene.ts @@ -0,0 +1,680 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the 'License'); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an 'AS IS' BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +module tf.graph.scene { + const svgNamespace = 'http://www.w3.org/2000/svg'; + + /** Enums element class of objects in the scene */ + export let Class = { + Node: { + // element that contains nodes. + CONTAINER: 'nodes', + // element that contains detail about a node. + GROUP: 'node', + // element that contains visual elements (like rect, ellipse). + SHAPE: 'nodeshape', + // <*> element(s) under SHAPE that should receive color updates. + COLOR_TARGET: 'nodecolortarget', + // element showing the node's label. + LABEL: 'nodelabel', + // element that contains all visuals for the expand/collapse + // button for expandable group nodes. + BUTTON_CONTAINER: 'buttoncontainer', + // element that surrounds expand/collapse buttons. + BUTTON_CIRCLE: 'buttoncircle', + // element of the expand button. + EXPAND_BUTTON: 'expandbutton', + // element of the collapse button. + COLLAPSE_BUTTON: 'collapsebutton' + }, + Edge: { + CONTAINER: 'edges', + GROUP: 'edge', + LINE: 'edgeline', + REF_LINE: 'refline', + STRUCTURAL: 'structural' + }, + Annotation: { + OUTBOX: 'out-annotations', + INBOX: 'in-annotations', + GROUP: 'annotation', + NODE: 'annotation-node', + EDGE: 'annotation-edge', + CONTROL_EDGE: 'annotation-control-edge', + LABEL: 'annotation-label', + ELLIPSIS: 'annotation-ellipsis' + }, + Scene: { + GROUP: 'scene', + CORE: 'core', + INEXTRACT: 'in-extract', + OUTEXTRACT: 'out-extract' + }, + Subscene: {GROUP: 'subscene'}, + OPNODE: 'op', + METANODE: 'meta', + SERIESNODE: 'series', + BRIDGENODE: 'bridge', + ELLIPSISNODE: 'ellipsis' + }; + + /** + * A health pill encapsulates an overview of tensor element values. The value + * field is a list of 12 numbers that shed light on the status of the tensor. + * Visualized in health pills are the 3rd through 8th (inclusive) numbers of + * health pill values. Those 6 numbers are counts of tensor elements that fall + * under -Inf, negative, 0, positive, +Inf, NaN (in that order). + * + * Please keep this interface consistent with HealthPillDatum within + * backend.ts. + */ + export interface HealthPill { + node_name: string; + output_slot: number; + value: number[]; + wall_time: number; + step: number; + } + ; + + /** + * Encapsulates how to render a single entry in a health pill. Each entry + * corresponds to a category of tensor element values. + */ + export interface HealthPillEntry { + background_color: string; + label: string; + } + ; + export let healthPillEntries: HealthPillEntry[] = [ + { + background_color: '#CC2F2C', + label: 'NaN', + }, + { + background_color: '#FF8D00', + label: '- ∞', + }, + { + background_color: '#EAEAEA', + label: '-', + }, + { + background_color: '#A5A5A5', + label: '0', + }, + { + background_color: '#262626', + label: '+', + }, + { + background_color: '#003ED4', + label: '+ ∞', + }, + ]; + + /** + * Helper method for fitting the graph in the svg view. + * + * @param svg The main svg. + * @param zoomG The svg group used for panning and zooming. + * @param d3zoom The zoom behavior. + * @param callback Called when the fitting is done. + */ + export function fit(svg, zoomG, d3zoom, callback) { + let svgRect = svg.getBoundingClientRect(); + let sceneSize = null; + try { + sceneSize = zoomG.getBBox(); + if (sceneSize.width === 0) { + // There is no scene anymore. We have been detached from the dom. + return; + } + } catch (e) { + // Firefox produced NS_ERROR_FAILURE if we have been + // detached from the dom. + return; + } + let scale = 0.9 * + Math.min( + svgRect.width / sceneSize.width, svgRect.height / sceneSize.height, + 2); + let params = layout.PARAMS.graph; + const transform = d3.zoomIdentity + .scale(scale) + .translate(params.padding.paddingLeft, params.padding.paddingTop); + + d3.select(zoomG) + .transition() + .duration(500) + .call(d3zoom.transform, transform) + .on('end.fitted', () => { + // Remove the listener for the zoomend event, + // so we don't get called at the end of regular zoom events, + // just those that fit the graph to screen. + d3zoom.on('end.fitted', null); + callback(); + }); +}; + +/** + * Helper method for panning the graph to center on the provided node, + * if the node is currently off-screen. + * + * @param nodeName The node to center the graph on + * @param svg The root SVG element for the graph + * @param zoomG The svg group used for panning and zooming. + * @param d3zoom The zoom behavior. + * @return True if the graph had to be panned to display the + * provided node. + */ +export function panToNode(nodeName: String, svg, zoomG, d3zoom): boolean { + let node = d3 + .select('[data-name="' + nodeName + '"].' + Class.Node.GROUP) + .node(); + if (!node) { + return false; + } + let translate = d3zoom.translate(); + // Check if the selected node is off-screen in either + // X or Y dimension in either direction. + let nodeBox = node.getBBox(); + let nodeCtm = node.getScreenCTM(); + let pointTL = svg.createSVGPoint(); + let pointBR = svg.createSVGPoint(); + pointTL.x = nodeBox.x; + pointTL.y = nodeBox.y; + pointBR.x = nodeBox.x + nodeBox.width; + pointBR.y = nodeBox.y + nodeBox.height; + pointTL = pointTL.matrixTransform(nodeCtm); + pointBR = pointBR.matrixTransform(nodeCtm); + let isOutsideOfBounds = (start, end, bound) => { + return end < 0 || start > bound; + }; + let svgRect = svg.getBoundingClientRect(); + if (isOutsideOfBounds(pointTL.x, pointBR.x, svgRect.width) || + isOutsideOfBounds(pointTL.y, pointBR.y, svgRect.height)) { + // Determine the amount to transform the graph in both X and Y + // dimensions in order to center the selected node. This takes into + // acount the position of the node, the size of the svg scene, the + // amount the scene has been scaled by through zooming, and any previous + // transform already performed by this logic. + let centerX = (pointTL.x + pointBR.x) / 2; + let centerY = (pointTL.y + pointBR.y) / 2; + let dx = ((svgRect.width / 2) - centerX); + let dy = ((svgRect.height / 2) - centerY); + let zoomEvent = d3zoom.translate([translate[0] + dx, translate[1] + dy]) + .event; + d3.select(zoomG).transition().duration(500).call(zoomEvent); + return true; + } + return false; +}; + +/** + * Given a container d3 selection, select a child svg element of a given tag + * and class if exists or append / insert one otherwise. If multiple children + * matches the tag and class name, returns only the first one. + * + * @param container + * @param tagName tag name. + * @param className (optional) Class name or a list of class names. + * @param before (optional) reference DOM node for insertion. + * @return selection of the element + */ +export function selectOrCreateChild( + container, tagName: string, className?: string | string[], before?) { + let child = selectChild(container, tagName, className); + if (!child.empty()) { + return child; + } + let newElement = + document.createElementNS('http://www.w3.org/2000/svg', tagName); + + if (className instanceof Array) { + for (let i = 0; i < className.length; i++) { + newElement.classList.add(className[i]); + } + } else { + newElement.classList.add(className); + } + + if (before) { // if before exists, insert + container.node().insertBefore(newElement, before); + } else { // otherwise, append + container.node().appendChild(newElement); + } + return d3.select(newElement) + // need to bind data to emulate d3_selection.append + .datum(container.datum()); +}; + +/** + * Given a container d3 selection, select a child element of a given tag and + * class. If multiple children matches the tag and class name, returns only + * the first one. + * + * @param container + * @param tagName tag name. + * @param className (optional) Class name or list of class names. + * @return selection of the element, or an empty selection + */ +export function selectChild( + container, tagName: string, className?: string | string[]) { + let children = container.node().childNodes; + for (let i = 0; i < children.length; i++) { + let child = children[i]; + if (child.tagName === tagName) { + if (className instanceof Array) { + let hasAllClasses = true; + for (let j = 0; j < className.length; j++) { + hasAllClasses = + hasAllClasses && child.classList.contains(className[j]); + } + if (hasAllClasses) { + return d3.select(child); + } + } else if ((!className || child.classList.contains(className))) { + return d3.select(child); + } + } + } + return d3.select(null); +}; + +/** + * Select or create a sceneGroup and build/update its nodes and edges. + * + * Structure Pattern: + * + * + * + * + * ... stuff from tf.graph.scene.edges.build ... + * + * + * ... stuff from tf.graph.scene.nodes.build ... + * + * + * + * + * ... stuff from tf.graph.scene.nodes.build ... + * + * + * + * + * ... stuff from tf.graph.scene.nodes.build ... + * + * + * + * + * @param container D3 selection of the parent. + * @param renderNode render node of a metanode or series node. + * @param sceneElement polymer element. + * @param sceneClass class attribute of the scene (default='scene'). + */ +export function buildGroup(container, + renderNode: render.RenderGroupNodeInfo, + sceneElement, + sceneClass: string) { + sceneClass = sceneClass || Class.Scene.GROUP; + let isNewSceneGroup = selectChild(container, 'g', sceneClass).empty(); + let sceneGroup = selectOrCreateChild(container, 'g', sceneClass); + + // core + let coreGroup = selectOrCreateChild(sceneGroup, 'g', Class.Scene.CORE); + let coreNodes = _.reduce(renderNode.coreGraph.nodes(), (nodes, name) => { + let node = renderNode.coreGraph.node(name); + if (!node.excluded) { + nodes.push(node); + } + return nodes; + }, []); + + if (renderNode.node.type === NodeType.SERIES) { + // For series, we want the first item on top, so reverse the array so + // the first item in the series becomes last item in the top, and thus + // is rendered on the top. + coreNodes.reverse(); + } + + // Create the layer of edges for this scene (paths). + edge.buildGroup(coreGroup, renderNode.coreGraph, sceneElement); + + // Create the layer of nodes for this scene (ellipses, rects etc). + node.buildGroup(coreGroup, coreNodes, sceneElement); + + // In-extract + if (renderNode.isolatedInExtract.length > 0) { + let inExtractGroup = + selectOrCreateChild(sceneGroup, 'g', Class.Scene.INEXTRACT); + node.buildGroup(inExtractGroup, renderNode.isolatedInExtract, + sceneElement); + } else { + selectChild(sceneGroup, 'g', Class.Scene.INEXTRACT).remove(); + } + + // Out-extract + if (renderNode.isolatedOutExtract.length > 0) { + let outExtractGroup = + selectOrCreateChild(sceneGroup, 'g', Class.Scene.OUTEXTRACT); + node.buildGroup(outExtractGroup, renderNode.isolatedOutExtract, + sceneElement); + } else { + selectChild(sceneGroup, 'g', Class.Scene.OUTEXTRACT).remove(); + } + + position(sceneGroup, renderNode); + + // Fade in the scene group if it didn't already exist. + if (isNewSceneGroup) { + sceneGroup.attr('opacity', 0).transition().attr('opacity', 1); + } + + return sceneGroup; +}; + +/** + * Given a scene's svg group, set g.in-extract, g.coreGraph, g.out-extract svg + * groups' position relative to the scene. + * + * @param sceneGroup + * @param renderNode render node of a metanode or series node. + */ +function position(sceneGroup, renderNode: render.RenderGroupNodeInfo) { + // Translate scenes down by the label height so that when showing graphs in + // expanded metanodes, the graphs are below the labels. Do not shift them + // down for series nodes as series nodes don't have labels inside of their + // bounding boxes. + let yTranslate = renderNode.node.type === NodeType.SERIES ? + 0 : layout.PARAMS.subscene.meta.labelHeight; + + // core + translate(selectChild(sceneGroup, 'g', Class.Scene.CORE), 0, yTranslate); + + // in-extract + let hasInExtract = renderNode.isolatedInExtract.length > 0; + let hasOutExtract = renderNode.isolatedOutExtract.length > 0; + + if (hasInExtract) { + let offset = layout.PARAMS.subscene.meta.extractXOffset; + let inExtractX = renderNode.coreBox.width - + renderNode.inExtractBox.width / 2 - renderNode.outExtractBox.width - + (hasOutExtract ? offset : 0); + translate( + selectChild(sceneGroup, 'g', Class.Scene.INEXTRACT), inExtractX, + yTranslate); + } + + // out-extract + if (hasOutExtract) { + let outExtractX = renderNode.coreBox.width - + renderNode.outExtractBox.width / 2; + translate( + selectChild(sceneGroup, 'g', Class.Scene.OUTEXTRACT), outExtractX, + yTranslate); + } +}; + +/** Adds a click listener to a group that fires a graph-select event */ +export function addGraphClickListener(graphGroup, sceneElement) { + d3.select(graphGroup).on('click', () => { + sceneElement.fire('graph-select'); + }); +}; + +/** Helper for adding transform: translate(x0, y0) */ +export function translate(selection, x0: number, y0: number) { + // If it is already placed on the screen, make it a transition. + if (selection.attr('transform') != null) { + selection = selection.transition('position'); + } + selection.attr('transform', 'translate(' + x0 + ',' + y0 + ')'); +}; + +/** + * Helper for setting position of a svg rect + * @param rect rect to set position of. + * @param cx Center x. + * @param cy Center x. + * @param width Width to set. + * @param height Height to set. + */ +export function positionRect(rect, cx: number, cy: number, width: number, + height: number) { + rect.transition().attr({ + x: cx - width / 2, + y: cy - height / 2, + width: width, + height: height + }); +}; + +/** + * Helper for setting position of a svg expand/collapse button + * @param button container group + * @param renderNode the render node of the group node to position + * the button on. + */ +export function positionButton(button, renderNode: render.RenderNodeInfo) { + let cx = layout.computeCXPositionOfNodeShape(renderNode); + // Position the button in the top-right corner of the group node, + // with space given the draw the button inside of the corner. + let width = renderNode.expanded ? + renderNode.width : renderNode.coreBox.width; + let height = renderNode.expanded ? + renderNode.height : renderNode.coreBox.height; + let x = cx + width / 2 - 6; + let y = renderNode.y - height / 2 + 6; + // For unexpanded series nodes, the button has special placement due + // to the unique visuals of this group node. + if (renderNode.node.type === NodeType.SERIES && !renderNode.expanded) { + x += 10; + y -= 2; + } + let translateStr = 'translate(' + x + ',' + y + ')'; + button.selectAll('path').transition().attr('transform', translateStr); + button.select('circle').transition().attr( + {cx: x, cy: y, r: layout.PARAMS.nodeSize.meta.expandButtonRadius}); +}; + +/** + * Helper for setting position of a svg ellipse + * @param ellipse ellipse to set position of. + * @param cx Center x. + * @param cy Center x. + * @param width Width to set. + * @param height Height to set. + */ +export function positionEllipse(ellipse, cx: number, cy: number, + width: number, height: number) { + ellipse.transition().attr({ + cx: cx, + cy: cy, + rx: width / 2, + ry: height / 2 + }); +}; + +/** + * @param {number} stat A stat for a health pill (such as mean or variance). + * @param {boolean} shouldRoundOnesDigit Whether to round this number to the + * ones digit. Useful for say int, uint, and bool output types. + * @return {string} A human-friendly string representation of that stat. + */ +export function humanizeHealthPillStat(stat, shouldRoundOnesDigit) { + if (shouldRoundOnesDigit) { + return stat.toFixed(0); + } + + if (Math.abs(stat) >= 1) { + return stat.toFixed(1); + } + return stat.toExponential(1); +} + +/** + * Renders a health pill for an op atop a node. + */ +function _addHealthPill( + nodeGroupElement: SVGElement, healthPill: HealthPill, + nodeInfo: render.RenderNodeInfo) { + // Check if text already exists at location. + d3.select(nodeGroupElement.parentNode as any).selectAll('.health-pill').remove(); + + if (!nodeInfo || !healthPill) { + return; + } + + let lastHealthPillData = healthPill.value; + + // For now, we only visualize the 6 values that summarize counts of tensor + // elements of various categories: -Inf, negative, 0, positive, Inf, and NaN. + let lastHealthPillOverview = lastHealthPillData.slice(2, 8); + let totalCount = lastHealthPillData[1]; + + let healthPillWidth = 60; + let healthPillHeight = 10; + if (nodeInfo.node.type === tf.graph.NodeType.OP) { + // Use a smaller health pill for op nodes (rendered as smaller ellipses). + healthPillWidth /= 2; + healthPillHeight /= 2; + } + + let healthPillGroup = document.createElementNS(svgNamespace, 'g'); + healthPillGroup.classList.add('health-pill'); + + // Define the gradient for the health pill. + let healthPillDefs = document.createElementNS(svgNamespace, 'defs'); + healthPillGroup.appendChild(healthPillDefs); + let healthPillGradient = + document.createElementNS(svgNamespace, 'linearGradient'); + const healthPillGradientId = 'health-pill-gradient'; + healthPillGradient.setAttribute('id', healthPillGradientId); + let titleOnHoverTextEntries = []; + let cumulativeCount = 0; + let previousOffset = '0%'; + for (let i = 0; i < lastHealthPillOverview.length; i++) { + if (!lastHealthPillOverview[i]) { + // Exclude empty categories. + continue; + } + cumulativeCount += lastHealthPillOverview[i]; + + // Create a color interval using 2 stop elements. + let stopElement0 = document.createElementNS(svgNamespace, 'stop'); + stopElement0.setAttribute('offset', previousOffset); + stopElement0.setAttribute( + 'stop-color', healthPillEntries[i].background_color); + healthPillGradient.appendChild(stopElement0); + + let stopElement1 = document.createElementNS(svgNamespace, 'stop'); + let percent = (cumulativeCount * 100 / totalCount) + '%'; + stopElement1.setAttribute('offset', percent); + stopElement1.setAttribute( + 'stop-color', healthPillEntries[i].background_color); + healthPillGradient.appendChild(stopElement1); + previousOffset = percent; + + // Include this number in the title that appears on hover. + titleOnHoverTextEntries.push( + healthPillEntries[i].label + ': ' + lastHealthPillOverview[i]); + } + healthPillDefs.appendChild(healthPillGradient); + + // Create the rectangle for the health pill. + let rect = document.createElementNS(svgNamespace, 'rect'); + rect.setAttribute('fill', 'url(#' + healthPillGradientId + ')'); + rect.setAttribute('width', String(healthPillWidth)); + rect.setAttribute('height', String(healthPillHeight)); + healthPillGroup.appendChild(rect); + + // Show a title with specific counts on hover. + let titleSvg = document.createElementNS(svgNamespace, 'title'); + titleSvg.textContent = titleOnHoverTextEntries.join(', '); + healthPillGroup.appendChild(titleSvg); + + // Center this health pill just right above the node for the op. + let healthPillX = nodeInfo.x - healthPillWidth / 2; + let healthPillY = nodeInfo.y - healthPillHeight - nodeInfo.height / 2 - 2; + if (nodeInfo.labelOffset < 0) { + // The label is positioned above the node. Do not occlude the label. + healthPillY += nodeInfo.labelOffset; + } + + if (lastHealthPillOverview[2] || lastHealthPillOverview[3] || + lastHealthPillOverview[4]) { + // At least 1 "non-Inf and non-NaN" value exists (a -, 0, or + value). Show + // stats on tensor values. + + // Determine if we should display the output range as integers. + let shouldRoundOnesDigit = false; + let node = nodeInfo.node as OpNode; + let attributes = node.attr; + if (attributes && attributes.length) { + // Find the attribute for output type if there is one. + for (let i = 0; i < attributes.length; i++) { + if (attributes[i].key === 'T') { + // Note whether the output type is an integer. + let outputType = attributes[i].value['type']; + shouldRoundOnesDigit = + outputType && /^DT_(BOOL|INT|UINT)/.test(outputType); + break; + } + } + } + + let statsSvg = document.createElementNS(svgNamespace, 'text'); + const minString = + humanizeHealthPillStat(lastHealthPillData[8], shouldRoundOnesDigit); + const maxString = + humanizeHealthPillStat(lastHealthPillData[9], shouldRoundOnesDigit); + statsSvg.textContent = minString + ' ~ ' + maxString; + statsSvg.classList.add('health-pill-stats'); + statsSvg.setAttribute('x', String(healthPillWidth / 2)); + statsSvg.setAttribute('y', '-2'); + healthPillGroup.appendChild(statsSvg); + } + + healthPillGroup.setAttribute( + 'transform', 'translate(' + healthPillX + ', ' + healthPillY + ')'); + + Polymer.dom(nodeGroupElement.parentNode).appendChild(healthPillGroup); +} + +/** + * Adds health pills (which visualize tensor summaries) to a graph group. + * @param svgRoot The root SVG element of the graph to add heath pills to. + * @param nodeNamesToHealthPills An object mapping node name to health pill. + * @param colors A list of colors to use. + */ +export function addHealthPills( + svgRoot: SVGElement, nodeNamesToHealthPills: {[key: string]: HealthPill[]}, + healthPillStepIndex: number) { + if (!nodeNamesToHealthPills) { + // No health pill information available. + return; + } + + let svgRootSelection = d3.select(svgRoot); + svgRootSelection.selectAll('g.nodeshape') + .each(function(nodeInfo: render.RenderNodeInfo) { + // Only show health pill data for this node if it is available. + let healthPills = nodeNamesToHealthPills[nodeInfo.node.name]; + let healthPill = healthPills ? healthPills[healthPillStepIndex] : null; + _addHealthPill((this as SVGElement), healthPill, nodeInfo); + }); +}; + +} // close module diff --git a/tensorflow/tensorboard/components/tf_graph_common_d3v4/template.ts b/tensorflow/tensorboard/components/tf_graph_common_d3v4/template.ts new file mode 100644 index 0000000000..7800d46029 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_graph_common_d3v4/template.ts @@ -0,0 +1,305 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the 'License'); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an 'AS IS' BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +module tf.graph.template { + +/** + * Detect repeating patterns of subgraphs. + * Assign templateId to each subgraph if it belongs to a template. + * Returns clusters of similar subgraphs . + * + * @param graph + * @param verifyTemplate whether to run the template verification algorithm + * @return a dict (template id => Array of node names) + */ +export function detect(h, verifyTemplate): {[templateId: string]: string[]} { + // In any particular subgraph, there are either + // - leaf nodes (which do not have subgraph) + // - metanode nodes - some of them have only one member (singular metanode) + // and some have multiple members (non-singular metanode) + + // First, generate a nearest neighbor hash of metanode nodes. + let nnGroups = clusterSimilarSubgraphs(h); + + // For each metanode, compare its subgraph (starting from shallower groups) + // and assign template id. + let templates = groupTemplateAndAssignId(nnGroups, verifyTemplate); + + // Sort the templates by minimum level in the graph at which they appear, + // as this leads to optimal setting of the colors of each template for + // maximum differentiation. + return <{[templateId: string]: string[]}>_(templates) + .pairs() + .sortBy(function(pair: {level: number, nodes: string[]}[]) { + return pair[1].level; + }) + .map(function(pair: {level: number, nodes: string[]}[]) { + return [pair[0], pair[1].nodes]; + }) + .object() + .value(); +}; + +/** + * @return Unique string for a metanode based on depth, |V|, |E| and + * op type histogram. + */ +function getSignature(metanode) { + // depth= |V|= |E|= + let props = _.map( + { + 'depth': metanode.depth, + '|V|': metanode.metagraph.nodes().length, + '|E|': metanode.metagraph.edges().length + }, + function(v, k) { return k + '=' + v; }) + .join(' '); + + // optype1=count1,optype2=count2 + let ops = _.map(metanode.opHistogram, function(count, op) { + return op + '=' + count; + }).join(','); + + return props + ' [ops] ' + ops; +} + +/** + * Generate a nearest neighbor hash of metanodes + * based on depth, |V|, |E|, and opHistogram of their subgraph + * (excluding leaf nodes and singular metanodes). + * @param graph The graph + * @return Array of pairs of [signature, + * Object with min level of the template and an Array of tf.graph.Group] + * sort by ascending order of minimum depth at which metanode appears. + */ +function clusterSimilarSubgraphs(h: hierarchy.Hierarchy) { + /** a dict from metanode.signature() => Array of tf.graph.Groups */ + let hashDict = _(h.getNodeMap()).reduce( + (hash, node: OpNode|Metanode, name) => { + if (node.type !== NodeType.META) { + return hash; + } + let levelOfMetaNode = name.split('/').length - 1; + let signature = getSignature(node); + let templateInfo = hash[signature] || + {nodes: [], level: levelOfMetaNode}; + hash[signature] = templateInfo; + templateInfo.nodes.push(node); + if (templateInfo.level > levelOfMetaNode) { + templateInfo.level = levelOfMetaNode; + } + return hash; + }, {}); + + return _(hashDict) + .pairs() + // filter nn metanode with only one member + .filter(function(pair: {level: number, nodes: string[]}) { + return pair[1].nodes.length > 1; + }) + .sortBy(function(pair: {level: number, nodes: string[]}) { + // sort by depth + // (all members in the same nnGroup has equal depth) + return pair[1].nodes[0].depth; + }) + .value(); +} + +function groupTemplateAndAssignId(nnGroups, verifyTemplate) { + // For each metanode, compare its subgraph (starting from shallower groups) + // and assign template id. + let result: {[templateId: string]: {level: number, nodes: string[]}} = {}; + return _.reduce(nnGroups, function(templates, nnGroupPair) { + let signature = nnGroupPair[0], + nnGroup = nnGroupPair[1].nodes, + clusters = []; + + nnGroup.forEach(function(metanode) { + // check with each existing cluster + for (let i = 0; i < clusters.length; i++) { + let similar = !verifyTemplate || + isSimilarSubgraph( + clusters[i].metanode.metagraph, + metanode.metagraph + ); + // if similar, just add this metanode to the cluster + if (similar) { + // get template from the first one + metanode.templateId = clusters[i].metanode.templateId; + clusters[i].members.push(metanode.name); + return; + } + } + // otherwise create a new cluster with id 'signature [count] ' + metanode.templateId = signature + '[' + clusters.length + ']'; + clusters.push({ + metanode: metanode, + members: [metanode.name] + }); + }); + + clusters.forEach(function(c) { + templates[c.metanode.templateId] = { + level: nnGroupPair[1].level, + nodes: c.members + }; + }); + return templates; + }, result); +} + +function sortNodes(names: string[], + graph: graphlib.Graph, prefix: string) { + return _.sortByAll(names, + function(name) { + let node = graph.node(name); + return (node).op; + }, + function(name) { + let node = graph.node(name); + return (node).templateId; + }, + function(name) { + return graph.neighbors(name).length; + }, + function(name) { + return graph.predecessors(name).length; + }, + function(name) { + return graph.successors(name).length; + }, + function(name) { + return name.substr(prefix.length); + }); +} + +function isSimilarSubgraph(g1: graphlib.Graph, + g2: graphlib.Graph) { + if (!tf.graph.hasSimilarDegreeSequence(g1, g2)) { + return false; + } + + // if we want to skip, just return true here. + // return true; + + // Verify sequence by running DFS + let g1prefix = g1.graph().name; + let g2prefix = g2.graph().name; + + let visited1 = {}; + let visited2 = {}; + let stack = []; + + /** + * push sources or successors into the stack + * if the visiting pattern has been similar. + */ + function stackPushIfNotDifferent(n1, n2) { + let sub1 = n1.substr(g1prefix.length), + sub2 = n2.substr(g2prefix.length); + + /* tslint:disable */ + if (visited1[sub1] ^ visited2[sub1]) { + console.warn( + 'different visit pattern', '[' + g1prefix + ']', sub1, + '[' + g2prefix + ']', sub2); + return true; + } + /* tslint:enable */ + if (!visited1[sub1]) { // implied && !visited2[sub2] + visited1[sub1] = visited2[sub2] = true; + stack.push({n1: n1, n2: n2}); + } + + return false; + } + + // check if have same # of sources then sort and push + let sources1 = g1.sources(); + let sources2 = g2.sources(); + if (sources1.length !== sources2.length) { + /* tslint:disable */ + console.log('different source length'); + /* tslint:enable */ + return false; + } + sources1 = sortNodes(sources1, g1, g1prefix); + sources2 = sortNodes(sources2, g2, g2prefix); + + for (let i = 0; i < sources1.length; i++) { + let different = stackPushIfNotDifferent(sources1[i], sources2[i]); + if (different) { + return false; + } + } + + while (stack.length > 0) { + let cur = stack.pop(); + + // check node + let similar = isSimilarNode(g1.node(cur.n1), g2.node(cur.n2)); + if (!similar) { + return false; + } + + // check if have same # of successors then sort and push + let succ1 = g1.successors(cur.n1), succ2 = g2.successors(cur.n2); + if (succ1.length !== succ2.length) { + /* tslint:disable */ + console.log('# of successors mismatch', succ1, succ2); + /* tslint:enable */ + return false; + } + succ1 = sortNodes(succ1, g1, g1prefix); + succ2 = sortNodes(succ2, g2, g2prefix); + + for (let j = 0; j < succ1.length; j++) { + let different = stackPushIfNotDifferent(succ1[j], succ2[j]); + if (different) { + return false; + } + } + } + + return true; +} + +/** + * Returns if two nodes have identical structure. + */ +function isSimilarNode(n1: OpNode|Metanode|SeriesNode, + n2: OpNode|Metanode|SeriesNode): boolean { + if (n1.type === NodeType.META) { + // compare metanode + let metanode1 = n1; + let metanode2 = n2; + return metanode1.templateId && metanode2.templateId && + metanode1.templateId === metanode2.templateId; + } else if (n1.type === NodeType.OP && n2.type === NodeType.OP) { + // compare leaf node + return (n1).op === (n2).op; + } else if (n1.type === NodeType.SERIES && n2.type === NodeType.SERIES) { + // compare series node sizes and operations + // (only need to check one op as all op nodes are identical in series) + let sn1 = n1; + let sn2 = n2; + let seriesnode1Count = sn1.metagraph.nodeCount(); + return (seriesnode1Count === sn2.metagraph.nodeCount() && + (seriesnode1Count === 0 || + ((sn1.metagraph.node(sn1.metagraph.nodes()[0])).op === + (sn2.metagraph.node(sn2.metagraph.nodes()[0])).op))); + } + return false; +} +} diff --git a/tensorflow/tensorboard/components/tf_graph_common_d3v4/test/graph-test.ts b/tensorflow/tensorboard/components/tf_graph_common_d3v4/test/graph-test.ts new file mode 100644 index 0000000000..af3030197e --- /dev/null +++ b/tensorflow/tensorboard/components/tf_graph_common_d3v4/test/graph-test.ts @@ -0,0 +1,103 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the 'License'); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an 'AS IS' BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +suite('graph', () => { + let assert = chai.assert; + + test('graphlib exists', () => { assert.isTrue(graphlib != null); }); + + test('simple graph contruction', done => { + let pbtxt = tf.graph.test.util.stringToArrayBuffer(` + node { + name: "Q" + op: "Input" + } + node { + name: "W" + op: "Input" + } + node { + name: "X" + op: "MatMul" + input: "Q:2" + input: "W" + }`); + let statsPbtxt = tf.graph.test.util.stringToArrayBuffer(`step_stats { + dev_stats { + device: "cpu" + node_stats { + node_name: "Q" + all_start_micros: 10 + all_end_rel_micros: 4 + } + node_stats { + node_name: "Q" + all_start_micros: 12 + all_end_rel_micros: 4 + } + } + }`); + + let buildParams: tf.graph.BuildParams = { + enableEmbedding: true, + inEmbeddingTypes: ['Const'], + outEmbeddingTypes: ['^[a-zA-Z]+Summary$'], + refEdges: {} + }; + let dummyTracker = + tf.graph.util.getTracker({set: () => { return; }, progress: 0}); + tf.graph.parser.parseGraphPbTxt(pbtxt).then(nodes => { + tf.graph.build(nodes, buildParams, dummyTracker) + .then((slimGraph: tf.graph.SlimGraph) => { + assert.isTrue(slimGraph.nodes['X'] != null); + assert.isTrue(slimGraph.nodes['W'] != null); + assert.isTrue(slimGraph.nodes['Q'] != null); + + let firstInputOfX = slimGraph.nodes['X'].inputs[0]; + assert.equal(firstInputOfX.name, 'Q'); + assert.equal(firstInputOfX.outputTensorIndex, 2); + + let secondInputOfX = slimGraph.nodes['X'].inputs[1]; + assert.equal(secondInputOfX.name, 'W'); + assert.equal(secondInputOfX.outputTensorIndex, 0); + + tf.graph.parser.parseStatsPbTxt(statsPbtxt).then(stepStats => { + tf.graph.joinStatsInfoWithGraph(slimGraph, stepStats); + assert.equal(slimGraph.nodes['Q'].stats.getTotalMicros(), 6); + done(); + }); + }); + }); + }); + + test('health pill numbers round correctly', () => { + // Integers are rounded to the ones place. + assert.equal(tf.graph.scene.humanizeHealthPillStat(42.0, true), '42'); + + // Numbers with magnitude >= 1 are rounded to the tenths place. + assert.equal(tf.graph.scene.humanizeHealthPillStat(1, false), '1.0'); + assert.equal(tf.graph.scene.humanizeHealthPillStat(42.42, false), '42.4'); + assert.equal(tf.graph.scene.humanizeHealthPillStat(-42.42, false), '-42.4'); + + // Numbers with magnitude < 1 are written in scientific notation rounded to + // the tenths place. + assert.equal(tf.graph.scene.humanizeHealthPillStat(0, false), '0.0e+0'); + assert.equal(tf.graph.scene.humanizeHealthPillStat(0.42, false), '4.2e-1'); + assert.equal( + tf.graph.scene.humanizeHealthPillStat(-0.042, false), '-4.2e-2'); + }); + + // TODO(bp): write tests. +}); diff --git a/tensorflow/tensorboard/components/tf_graph_common_d3v4/test/hierarchy-test.ts b/tensorflow/tensorboard/components/tf_graph_common_d3v4/test/hierarchy-test.ts new file mode 100644 index 0000000000..fa62ffe2c7 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_graph_common_d3v4/test/hierarchy-test.ts @@ -0,0 +1,23 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the 'License'); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an 'AS IS' BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +suite('graph', () => { + let assert = chai.assert; + + test('graphlib exists', () => { assert.isTrue(graphlib != null); }); + + // TODO(bp): write tests. + +}); diff --git a/tensorflow/tensorboard/components/tf_graph_common_d3v4/test/index.html b/tensorflow/tensorboard/components/tf_graph_common_d3v4/test/index.html new file mode 100644 index 0000000000..7564167129 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_graph_common_d3v4/test/index.html @@ -0,0 +1,34 @@ + + + + + + + + + + + + + + + + + + + diff --git a/tensorflow/tensorboard/components/tf_graph_common_d3v4/test/layout-test.ts b/tensorflow/tensorboard/components/tf_graph_common_d3v4/test/layout-test.ts new file mode 100644 index 0000000000..b4884413c9 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_graph_common_d3v4/test/layout-test.ts @@ -0,0 +1,23 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the 'License'); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an 'AS IS' BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +suite('layout', () => { + let assert = chai.assert; + + test('dagre exists', () => { assert.isTrue(dagre != null); }); + + // TODO(bp): write tests. + +}); diff --git a/tensorflow/tensorboard/components/tf_graph_common_d3v4/test/parser-test.ts b/tensorflow/tensorboard/components/tf_graph_common_d3v4/test/parser-test.ts new file mode 100644 index 0000000000..7c73178c1c --- /dev/null +++ b/tensorflow/tensorboard/components/tf_graph_common_d3v4/test/parser-test.ts @@ -0,0 +1,83 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the 'License'); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an 'AS IS' BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +suite('parser', () => { + let assert = chai.assert; + + test('simple pbtxt', done => { + let pbtxt = tf.graph.test.util.stringToArrayBuffer(`node { + name: "Q" + op: "Input" + } + node { + name: "W" + op: "Input" + } + node { + name: "X" + op: "MatMul" + input: "Q" + input: "W" + }`); + tf.graph.parser.parseGraphPbTxt(pbtxt).then(nodes => { + assert.isTrue(nodes != null && nodes.length === 3); + + assert.equal('Q', nodes[0].name); + assert.equal('Input', nodes[0].op); + + assert.equal('W', nodes[1].name); + assert.equal('Input', nodes[1].op); + + assert.equal('X', nodes[2].name); + assert.equal('MatMul', nodes[2].op); + assert.equal('Q', nodes[2].input[0]); + assert.equal('W', nodes[2].input[1]); + + done(); + }); + }); + + test('stats pbtxt parsing', done => { + let statsPbtxt = tf.graph.test.util.stringToArrayBuffer(`step_stats { + dev_stats { + device: "cpu" + node_stats { + node_name: "Q" + all_start_micros: 10 + all_end_rel_micros: 4 + } + node_stats { + node_name: "Q" + all_start_micros: 12 + all_end_rel_micros: 4 + } + } + }`); + tf.graph.parser.parseStatsPbTxt(statsPbtxt).then(stepStats => { + assert.equal(stepStats.dev_stats.length, 1); + assert.equal(stepStats.dev_stats[0].device, 'cpu'); + assert.equal(stepStats.dev_stats[0].node_stats.length, 2); + assert.equal(stepStats.dev_stats[0].node_stats[0].all_start_micros, 10); + assert.equal(stepStats.dev_stats[0].node_stats[1].node_name, 'Q'); + assert.equal(stepStats.dev_stats[0].node_stats[1].all_end_rel_micros, 4); + done(); + }); + }); + + test('d3 exists', () => { assert.isTrue(d3 != null); }); + + // TODO(nsthorat): write tests. + +}); diff --git a/tensorflow/tensorboard/components/tf_graph_common_d3v4/test/util-test.ts b/tensorflow/tensorboard/components/tf_graph_common_d3v4/test/util-test.ts new file mode 100644 index 0000000000..4535d24888 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_graph_common_d3v4/test/util-test.ts @@ -0,0 +1,56 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the 'License'); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an 'AS IS' BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +suite('util', () => { + let assert = chai.assert; + + test('remove common prefix', () => { + + // Empty array. + let result = tf.graph.util.removeCommonPrefix([]); + assert.deepEqual(result, []); + + // No common prefix. + result = tf.graph.util.removeCommonPrefix(['a', 'b', 'c']); + assert.deepEqual(result, ['a', 'b', 'c']); + + // One of the elements is empty string. + result = tf.graph.util.removeCommonPrefix(['a/b', '', 'a/c']); + assert.deepEqual(result, ['a/b', '', 'a/c']); + + // Only one string. + result = tf.graph.util.removeCommonPrefix(['a/b/c']); + assert.deepEqual(result, ['a/b/c']); + + // `q/w/` is the common prefix. Expect `q/w/` to be removed. + result = tf.graph.util.removeCommonPrefix(['q/w/a', 'q/w/b', 'q/w/c/f']); + assert.deepEqual(result, ['a', 'b', 'c/f']); + + // `q/w/` is the common prefix and also an element. Expect nothing to be + // removed since the common prefix is also an element in the array. + result = tf.graph.util.removeCommonPrefix(['q/w/', 'q/w/b', 'q/w/c/f']); + assert.deepEqual(result, ['q/w/', 'q/w/b', 'q/w/c/f']); + }); + + test('query params', () => { + // Starts with question mark. + let queryParams = tf.graph.util.getQueryParams('?foo=1&bar=2'); + assert.deepEqual(queryParams, {'foo': '1', 'bar': '2'}); + + // No question mark. + queryParams = tf.graph.util.getQueryParams('foo=1&bar=2'); + assert.deepEqual(queryParams, {'foo': '1', 'bar': '2'}); + }); +}); diff --git a/tensorflow/tensorboard/components/tf_graph_common_d3v4/test/util.ts b/tensorflow/tensorboard/components/tf_graph_common_d3v4/test/util.ts new file mode 100644 index 0000000000..bc73b735ed --- /dev/null +++ b/tensorflow/tensorboard/components/tf_graph_common_d3v4/test/util.ts @@ -0,0 +1,31 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the 'License'); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an 'AS IS' BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + + +/* tslint:disable:no-namespace */ +module tf.graph.test.util { + /** + * Converts a utf-8 string to an ArrayBuffer. + */ + export function stringToArrayBuffer(str): ArrayBuffer { + let buf = new ArrayBuffer(str.length); + let bufView = new Uint8Array(buf); + for (let i = 0, strLen = str.length; i < strLen; i++) { + bufView[i] = str.charCodeAt(i); + } + return buf; + } + +} // module diff --git a/tensorflow/tensorboard/components/tf_graph_common_d3v4/tf-graph-common.html b/tensorflow/tensorboard/components/tf_graph_common_d3v4/tf-graph-common.html new file mode 100644 index 0000000000..a460072a38 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_graph_common_d3v4/tf-graph-common.html @@ -0,0 +1,38 @@ + + + + + + + + + + + + + + + + + + + + + + + diff --git a/tensorflow/tensorboard/components/tf_graph_common_d3v4/util.ts b/tensorflow/tensorboard/components/tf_graph_common_d3v4/util.ts new file mode 100644 index 0000000000..7f4d329e79 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_graph_common_d3v4/util.ts @@ -0,0 +1,291 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the 'License'); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an 'AS IS' BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +/** + * @fileoverview Utility functions for the tensorflow graph visualizer. + */ + +module tf.graph.util { + /** + * Recommended delay (ms) when running an expensive task asynchronously + * that gives enough time for the progress bar to update its UI. + */ + const ASYNC_TASK_DELAY = 20; + + export function time(msg: string, task: () => T) { + let start = Date.now(); + let result = task(); + /* tslint:disable */ + console.log(msg, ':', Date.now() - start, 'ms'); + /* tslint:enable */ + return result; + } + + /** + * Creates a tracker that sets the progress property of the + * provided polymer component. The provided component must have + * a property called 'progress' that is not read-only. The progress + * property is an object with a numerical 'value' property and a + * string 'msg' property. + */ + export function getTracker(polymerComponent: any) { + return { + setMessage: function(msg) { + polymerComponent.set( + 'progress', {value: polymerComponent.progress.value, msg: msg}); + }, + updateProgress: function(value) { + polymerComponent.set('progress', { + value: polymerComponent.progress.value + value, + msg: polymerComponent.progress.msg + }); + }, + reportError: function(msg: string, err) { + // Log the stack trace in the console. + console.error(err.stack); + // And send a user-friendly message to the UI. + polymerComponent.set( + 'progress', + {value: polymerComponent.progress.value, msg: msg, error: true}); + }, + }; + } + + /** + * Creates a tracker for a subtask given the parent tracker, the total + * progress + * of the subtask and the subtask message. The parent task should pass a + * subtracker to its subtasks. The subtask reports its own progress which + * becames relative to the main task. + */ + export function getSubtaskTracker( + parentTracker: ProgressTracker, impactOnTotalProgress: number, + subtaskMsg: string): ProgressTracker { + return { + setMessage: function(progressMsg) { + // The parent should show a concatenation of its message along with + // its subtask tracker message. + parentTracker.setMessage(subtaskMsg + ': ' + progressMsg); + }, + updateProgress: function(incrementValue) { + // Update the parent progress relative to the child progress. + // For example, if the sub-task progresses by 30%, and the impact on the + // total progress is 50%, then the task progresses by 30% * 50% = 15%. + parentTracker.updateProgress( + incrementValue * impactOnTotalProgress / 100); + }, + reportError: function(msg: string, err: Error) { + // The parent should show a concatenation of its message along with + // its subtask error message. + parentTracker.reportError(subtaskMsg + ': ' + msg, err); + } + }; + } + + /** + * Runs an expensive task and return the result. + */ + export function runTask( + msg: string, incProgressValue: number, task: () => T, + tracker: ProgressTracker): T { + // Update the progress message to say the current running task. + tracker.setMessage(msg); + // Run the expensive task with a delay that gives enough time for the + // UI to update. + try { + let result = tf.graph.util.time(msg, task); + // Update the progress value. + tracker.updateProgress(incProgressValue); + // Return the result to be used by other tasks. + return result; + } catch (e) { + // Errors that happen inside asynchronous tasks are + // reported to the tracker using a user-friendly message. + tracker.reportError('Failed ' + msg, e); + } + } + + /** + * Runs an expensive task asynchronously and returns a promise of the result. + */ + export function runAsyncTask( + msg: string, incProgressValue: number, task: () => T, + tracker: ProgressTracker): Promise { + return new Promise((resolve, reject) => { + // Update the progress message to say the current running task. + tracker.setMessage(msg); + // Run the expensive task with a delay that gives enough time for the + // UI to update. + setTimeout(function() { + try { + let result = tf.graph.util.time(msg, task); + // Update the progress value. + tracker.updateProgress(incProgressValue); + // Return the result to be used by other tasks. + resolve(result); + } catch (e) { + // Errors that happen inside asynchronous tasks are + // reported to the tracker using a user-friendly message. + tracker.reportError('Failed ' + msg, e); + } + }, ASYNC_TASK_DELAY); + }); + } + + /** + * Asynchronously runs an expensive task that returns a promise. Updates the + * tracker's progress after the promise resolves. Returns a new promise that + * resolves after the progress is updated. + */ + export function runAsyncPromiseTask( + msg: string, incProgressValue: number, task: () => Promise, + tracker: ProgressTracker): Promise { + return new Promise((resolve, reject) => { + let handleError = function(e) { + // Errors that happen inside asynchronous tasks are + // reported to the tracker using a user-friendly message. + tracker.reportError('Failed ' + msg, e); + reject(e); + }; + + // Update the progress message to say the current running task. + tracker.setMessage(msg); + // Run the expensive task with a delay that gives enough time for the + // UI to update. + setTimeout(function() { + try { + let start = Date.now(); + task() + .then(function(value) { + /* tslint:disable */ + console.log(msg, ':', Date.now() - start, 'ms'); + // Update the progress value. + tracker.updateProgress(incProgressValue); + // Return the result to be used by other tasks. + resolve(value); + }) + .catch(handleError); + } catch (e) { + handleError(e); + } + }, ASYNC_TASK_DELAY); + }); + } + + /** + * Returns a query selector with escaped special characters that are not + * allowed in a query selector. + */ + export function escapeQuerySelector(querySelector: string): string { + return querySelector.replace(/([:.\[\],/\\\(\)])/g, '\\$1'); + } + + // For unit conversion. + export const MEMORY_UNITS = [ + // Atomic unit. + {symbol: 'B'}, + // numUnits specifies how many previous units this unit contains. + {symbol: 'KB', numUnits: 1024}, {symbol: 'MB', numUnits: 1024}, + {symbol: 'GB', numUnits: 1024}, {symbol: 'TB', numUnits: 1024}, + {symbol: 'PB', numUnits: 1024} + ]; + export const TIME_UNITS = [ + // Atomic unit. Finest granularity in TensorFlow stat collection. + {symbol: 'µs'}, + // numUnits specifies how many previous units this unit contains. + {symbol: 'ms', numUnits: 1000}, {symbol: 's', numUnits: 1000}, + {symbol: 'min', numUnits: 60}, {symbol: 'hr', numUnits: 60}, + {symbol: 'days', numUnits: 24} + ]; + + /** + * Returns the human readable version of the unit. + * (e.g. 1.35 GB, 23 MB, 34 ms, 6.53 min etc). + */ + export function convertUnitsToHumanReadable(value, units, unitIndex) { + unitIndex = unitIndex == null ? 0 : unitIndex; + if (unitIndex + 1 < units.length && + value >= units[unitIndex + 1].numUnits) { + return tf.graph.util.convertUnitsToHumanReadable( + value / units[unitIndex + 1].numUnits, units, unitIndex + 1); + } + // toPrecision() has the tendency to return a number in scientific + // notation and (number - 0) brings it back to normal notation. + return (value.toPrecision(3) - 0) + ' ' + units[unitIndex].symbol; + } + + export function hasDisplayableNodeStats(stats: NodeStats) { + if (stats && + (stats.totalBytes > 0 || stats.getTotalMicros() > 0 || + stats.outputSize)) { + return true; + } + return false; + } + + /** + * Given a list of strings, it returns a new list of strings with the longest + * common prefix removed. If the common prefix is one of the strings in the + * list, it returns the original strings. + */ + export function removeCommonPrefix(strings: string[]) { + if (strings.length < 2) { + return strings; + } + + let index = 0; + let largestIndex = 0; + // Find the shortest name across all strings. + let minLength = _.min(_.map(strings, str => str.length)); + while (true) { + index++; + let prefixes = _.map(strings, str => str.substring(0, index)); + let allTheSame = prefixes.every((prefix, i) => { + return (i === 0 ? true : prefix === prefixes[i - 1]); + }); + if (allTheSame) { + if (index >= minLength) { + // There is a string whose whole name is a prefix to other string. + // In this case, we return the original list of string. + return strings; + } + largestIndex = index; + } else { + break; + } + } + return _.map(strings, str => str.substring(largestIndex)); + } + + /** + * Given a queryString, aka ?foo=1&bar=2, return the object representation. + */ + export function getQueryParams(queryString: string) { + if (queryString.charAt(0) === '?') { + queryString = queryString.slice(1); + } + + let queryParams = _.chain(queryString.split('&')) + .map((item) => { + if (item) { + return item.split('='); + } + }) + .compact() + .value(); + + return _.object(queryParams); + } +} diff --git a/tensorflow/tensorboard/components/tf_graph_controls_d3v4/BUILD b/tensorflow/tensorboard/components/tf_graph_controls_d3v4/BUILD new file mode 100644 index 0000000000..d5f9a76eb2 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_graph_controls_d3v4/BUILD @@ -0,0 +1,32 @@ +package(default_visibility = ["//tensorflow:internal"]) + +load("@io_bazel_rules_closure//closure:defs.bzl", "web_library") +load("//tensorflow/tensorboard:defs.bzl", "tensorboard_ts_library") +load("//tensorflow/tensorboard:defs.bzl", "tensorboard_webcomponent_library") + +licenses(["notice"]) # Apache 2.0 + +web_library( + name = "tf_graph_controls_d3v4", + srcs = [ + "tf-graph-controls.html", + ], + path = "/tf-graph-controls", + deps = [ + "//tensorflow/tensorboard/components/tf_dashboard_common_d3v4", + "//tensorflow/tensorboard/components/tf_graph_common_d3v4", + "@org_polymer", + "@org_polymer_paper_button", + "@org_polymer_paper_dropdown_menu", + "@org_polymer_paper_menu", + "@org_polymer_paper_radio_group", + "@org_polymer_paper_toggle_button", + "@org_polymer_paper_tooltip", + ], +) + +filegroup( + name = "all_files", + srcs = glob(["**"]), + tags = ["notsan"], +) diff --git a/tensorflow/tensorboard/components/tf_graph_controls_d3v4/demo/BUILD b/tensorflow/tensorboard/components/tf_graph_controls_d3v4/demo/BUILD new file mode 100644 index 0000000000..c47cb90a03 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_graph_controls_d3v4/demo/BUILD @@ -0,0 +1,24 @@ +package(default_visibility = ["//tensorflow:internal"]) + +load("@io_bazel_rules_closure//closure:defs.bzl", "web_library") + +licenses(["notice"]) # Apache 2.0 + +# bazel run //third_party/tensorflow/tensorboard/components/tf_graph_controls/demo +web_library( + name = "demo", + srcs = ["index.html"], + path = "/tf-graph-controls/demo", + deps = [ + "//tensorflow/tensorboard/components/tf_graph_controls", + "@org_polymer_iron_demo_helpers", + "@org_polymer_paper_styles", + "@org_polymer_webcomponentsjs", + ], +) + +filegroup( + name = "all_files", + srcs = glob(["**"]), + tags = ["notsan"], +) diff --git a/tensorflow/tensorboard/components/tf_graph_controls_d3v4/demo/index.html b/tensorflow/tensorboard/components/tf_graph_controls_d3v4/demo/index.html new file mode 100644 index 0000000000..8b12641b28 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_graph_controls_d3v4/demo/index.html @@ -0,0 +1,49 @@ + + + + + +TF Graph Controls Demo + + + + diff --git a/tensorflow/tensorboard/components/tf_graph_controls_d3v4/tf-graph-controls.html b/tensorflow/tensorboard/components/tf_graph_controls_d3v4/tf-graph-controls.html new file mode 100644 index 0000000000..10faf29bbc --- /dev/null +++ b/tensorflow/tensorboard/components/tf_graph_controls_d3v4/tf-graph-controls.html @@ -0,0 +1,829 @@ + + + + + + + + + + + + + + + + diff --git a/tensorflow/tensorboard/components/tf_graph_d3v4/BUILD b/tensorflow/tensorboard/components/tf_graph_d3v4/BUILD new file mode 100644 index 0000000000..367baeb67b --- /dev/null +++ b/tensorflow/tensorboard/components/tf_graph_d3v4/BUILD @@ -0,0 +1,37 @@ +package(default_visibility = ["//tensorflow:internal"]) + +load("@io_bazel_rules_closure//closure:defs.bzl", "web_library") +load("//tensorflow/tensorboard:defs.bzl", "tensorboard_ts_library") +load("//tensorflow/tensorboard:defs.bzl", "tensorboard_webcomponent_library") + +licenses(["notice"]) # Apache 2.0 + +web_library( + name = "tf_graph_d3v4", + srcs = [ + "tf-graph.html", + "tf-graph-minimap.html", + "tf-graph-scene.html", + ], + path = "/tf-graph", + deps = [ + "//tensorflow/tensorboard/components/tf_dashboard_common_d3v4", + "//tensorflow/tensorboard/components/tf_graph_common_d3v4", + "@org_polymer", + "@org_polymer_iron_flex_layout", + "@org_polymer_iron_icons", + "@org_polymer_paper_button", + "@org_polymer_paper_dropdown_menu", + "@org_polymer_paper_input", + "@org_polymer_paper_menu", + "@org_polymer_paper_radio_group", + "@org_polymer_paper_toggle_button", + "@org_polymer_paper_tooltip", + ], +) + +filegroup( + name = "all_files", + srcs = glob(["**"]), + tags = ["notsan"], +) diff --git a/tensorflow/tensorboard/components/tf_graph_d3v4/demo/BUILD b/tensorflow/tensorboard/components/tf_graph_d3v4/demo/BUILD new file mode 100644 index 0000000000..524d0ff767 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_graph_d3v4/demo/BUILD @@ -0,0 +1,26 @@ +package(default_visibility = ["//tensorflow:internal"]) + +load("@io_bazel_rules_closure//closure:defs.bzl", "web_library") + +licenses(["notice"]) # Apache 2.0 + +# bazel run //third_party/tensorflow/tensorboard/components/tf_graph/demo +web_library( + name = "demo", + srcs = ["index.html"] + glob(["data/**"]), + path = "/tf-graph/demo", + deps = [ + "//tensorflow/tensorboard/components/tf_graph", + "//tensorflow/tensorboard/components/tf_graph_common", + "//tensorflow/tensorboard/components/tf_graph_loader", + "@org_polymer_iron_demo_helpers", + "@org_polymer_paper_styles", + "@org_polymer_webcomponentsjs", + ], +) + +filegroup( + name = "all_files", + srcs = glob(["**"]), + tags = ["notsan"], +) diff --git a/tensorflow/tensorboard/components/tf_graph_d3v4/demo/data/graph.pbtxt b/tensorflow/tensorboard/components/tf_graph_d3v4/demo/data/graph.pbtxt new file mode 100644 index 0000000000..30b2064534 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_graph_d3v4/demo/data/graph.pbtxt @@ -0,0 +1,4606 @@ +node { + name: "GradientDescent/learning_rate" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_3" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.1 + } + } + } +} +node { + name: "gradients/add_grad/Shape_1" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 100 + } + } + } +} +node { + name: "gradients/add_grad/Shape" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\310\000\000\000d\000\000\000" + } + } + } +} +node { + name: "gradients/add_grad/BroadcastGradientArgs" + op: "BroadcastGradientArgs" + input: "gradients/add_grad/Shape" + input: "gradients/add_grad/Shape_1" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + shape { + dim { + size: -1 + } + } + } + } + } +} +node { + name: "gradients/add_1_grad/Shape_1" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 10 + } + } + } +} +node { + name: "gradients/add_1_grad/Shape" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\310\000\000\000\n\000\000\000" + } + } + } +} +node { + name: "gradients/add_1_grad/BroadcastGradientArgs" + op: "BroadcastGradientArgs" + input: "gradients/add_1_grad/Shape" + input: "gradients/add_1_grad/Shape_1" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + shape { + dim { + size: -1 + } + } + } + } + } +} +node { + name: "gradients/Reshape_1_grad/Shape" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\310\000\000\000\n\000\000\000" + } + } + } +} +node { + name: "gradients/SoftmaxCrossEntropyWithLogits_grad/ExpandDims/dim" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: -1 + } + } + } +} +node { + name: "gradients/Reshape_3_grad/Shape" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 200 + } + } + } +} +node { + name: "gradients/Mean_grad/Maximum/y" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } +} +node { + name: "gradients/Mean_grad/Const_1" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } +} +node { + name: "gradients/Mean_grad/Const" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } +} +node { + name: "gradients/Mean_grad/Shape_1" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + } + } + } + } + } +} +node { + name: "gradients/Mean_grad/Prod_1" + op: "Prod" + input: "gradients/Mean_grad/Shape_1" + input: "gradients/Mean_grad/Const_1" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } +} +node { + name: "gradients/Mean_grad/Maximum" + op: "Maximum" + input: "gradients/Mean_grad/Prod_1" + input: "gradients/Mean_grad/Maximum/y" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } +} +node { + name: "gradients/Mean_grad/Shape" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 200 + } + } + } +} +node { + name: "gradients/Mean_grad/Prod" + op: "Prod" + input: "gradients/Mean_grad/Shape" + input: "gradients/Mean_grad/Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } +} +node { + name: "gradients/Mean_grad/floordiv" + op: "FloorDiv" + input: "gradients/Mean_grad/Prod" + input: "gradients/Mean_grad/Maximum" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } +} +node { + name: "gradients/Mean_grad/Cast" + op: "Cast" + input: "gradients/Mean_grad/floordiv" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "DstT" + value { + type: DT_FLOAT + } + } + attr { + key: "SrcT" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } +} +node { + name: "gradients/Mean_grad/Tile/multiples" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 200 + } + } + } +} +node { + name: "gradients/Mean_grad/Reshape/shape" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } +} +node { + name: "gradients/Const" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1 + } + } + } +} +node { + name: "gradients/Shape" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + } + } + } + } + } +} +node { + name: "gradients/Fill" + op: "Fill" + input: "gradients/Shape" + input: "gradients/Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } +} +node { + name: "gradients/Mean_grad/Reshape" + op: "Reshape" + input: "gradients/Fill" + input: "gradients/Mean_grad/Reshape/shape" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } +} +node { + name: "gradients/Mean_grad/Tile" + op: "Tile" + input: "gradients/Mean_grad/Reshape" + input: "gradients/Mean_grad/Tile/multiples" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tmultiples" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + } + } + } + } +} +node { + name: "gradients/Mean_grad/truediv" + op: "RealDiv" + input: "gradients/Mean_grad/Tile" + input: "gradients/Mean_grad/Cast" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + } + } + } + } +} +node { + name: "gradients/Reshape_3_grad/Reshape" + op: "Reshape" + input: "gradients/Mean_grad/truediv" + input: "gradients/Reshape_3_grad/Shape" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + } + } + } + } +} +node { + name: "gradients/SoftmaxCrossEntropyWithLogits_grad/ExpandDims" + op: "ExpandDims" + input: "gradients/Reshape_3_grad/Reshape" + input: "gradients/SoftmaxCrossEntropyWithLogits_grad/ExpandDims/dim" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tdim" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 1 + } + } + } + } + } +} +node { + name: "Const" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } +} +node { + name: "Slice_2/begin" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } +} +node { + name: "Sub_2/y" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } +} +node { + name: "concat_1/axis" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 0 + } + } + } +} +node { + name: "concat_1/values_0" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: -1 + } + } + } +} +node { + name: "Slice_1/size" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } +} +node { + name: "Sub_1/y" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } +} +node { + name: "Shape_2" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\310\000\000\000\n\000\000\000" + } + } + } +} +node { + name: "Rank_2" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 2 + } + } + } +} +node { + name: "Sub_1" + op: "Sub" + input: "Rank_2" + input: "Sub_1/y" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } +} +node { + name: "Slice_1/begin" + op: "Pack" + input: "Sub_1" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "N" + value { + i: 1 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "axis" + value { + i: 0 + } + } +} +node { + name: "Slice_1" + op: "Slice" + input: "Shape_2" + input: "Slice_1/begin" + input: "Slice_1/size" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } +} +node { + name: "concat_1" + op: "ConcatV2" + input: "concat_1/values_0" + input: "Slice_1" + input: "concat_1/axis" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } +} +node { + name: "concat/axis" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 0 + } + } + } +} +node { + name: "concat/values_0" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: -1 + } + } + } +} +node { + name: "Slice/size" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } +} +node { + name: "Sub/y" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } +} +node { + name: "Shape_1" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\310\000\000\000\n\000\000\000" + } + } + } +} +node { + name: "Rank_1" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 2 + } + } + } +} +node { + name: "Sub" + op: "Sub" + input: "Rank_1" + input: "Sub/y" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } +} +node { + name: "Slice/begin" + op: "Pack" + input: "Sub" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "N" + value { + i: 1 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "axis" + value { + i: 0 + } + } +} +node { + name: "Slice" + op: "Slice" + input: "Shape_1" + input: "Slice/begin" + input: "Slice/size" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } +} +node { + name: "concat" + op: "ConcatV2" + input: "concat/values_0" + input: "Slice" + input: "concat/axis" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } +} +node { + name: "Shape" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\310\000\000\000\n\000\000\000" + } + } + } +} +node { + name: "Rank" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 2 + } + } + } +} +node { + name: "Sub_2" + op: "Sub" + input: "Rank" + input: "Sub_2/y" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } +} +node { + name: "Slice_2/size" + op: "Pack" + input: "Sub_2" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "N" + value { + i: 1 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "axis" + value { + i: 0 + } + } +} +node { + name: "Slice_2" + op: "Slice" + input: "Shape" + input: "Slice_2/begin" + input: "Slice_2/size" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } +} +node { + name: "logits_biases" + op: "VariableV2" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "_class" + value { + list { + s: "loc:@logits_biases" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 10 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "logits_biases/read" + op: "Identity" + input: "logits_biases" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@logits_biases" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } +} +node { + name: "logits_weights" + op: "VariableV2" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "_class" + value { + list { + s: "loc:@logits_weights" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 100 + } + dim { + size: 10 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "logits_weights/read" + op: "Identity" + input: "logits_weights" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@logits_weights" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + dim { + size: 10 + } + } + } + } + } +} +node { + name: "hidden_biases" + op: "VariableV2" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "_class" + value { + list { + s: "loc:@hidden_biases" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + } + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 100 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "hidden_biases/read" + op: "Identity" + input: "hidden_biases" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@hidden_biases" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + } + } + } + } +} +node { + name: "hidden_weights" + op: "VariableV2" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "_class" + value { + list { + s: "loc:@hidden_weights" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 784 + } + dim { + size: 100 + } + } + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 784 + } + dim { + size: 100 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "hidden_weights/read" + op: "Identity" + input: "hidden_weights" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@hidden_weights" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 784 + } + dim { + size: 100 + } + } + } + } + } +} +node { + name: "Reshape/shape" + op: "Const" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\310\000\000\000\377\377\377\377" + } + } + } +} +node { + name: "mnist_dataset_train_2/one_hot/depth" + op: "Const" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 10 + } + } + } +} +node { + name: "mnist_dataset_train_2/one_hot/off_value" + op: "Const" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0 + } + } + } +} +node { + name: "mnist_dataset_train_2/one_hot/on_value" + op: "Const" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1 + } + } + } +} +node { + name: "mnist_dataset_train_2/random_shuffle_queue_DequeueMany/n" + op: "Const" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 200 + } + } + } +} +node { + name: "mnist_dataset_train_1/random_shuffle_queue" + op: "RandomShuffleQueueV2" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "capacity" + value { + i: 20000 + } + } + attr { + key: "component_types" + value { + list { + type: DT_FLOAT + type: DT_INT64 + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "min_after_dequeue" + value { + i: 4000 + } + } + attr { + key: "seed" + value { + i: 0 + } + } + attr { + key: "seed2" + value { + i: 0 + } + } + attr { + key: "shapes" + value { + list { + shape { + dim { + size: 28 + } + dim { + size: 28 + } + dim { + size: 1 + } + } + shape { + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "mnist_dataset_train_2/random_shuffle_queue_DequeueMany" + op: "QueueDequeueManyV2" + input: "mnist_dataset_train_1/random_shuffle_queue" + input: "mnist_dataset_train_2/random_shuffle_queue_DequeueMany/n" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "_output_shapes" + value { + list { + shape { + unknown_rank: true + } + shape { + unknown_rank: true + } + } + } + } + attr { + key: "component_types" + value { + list { + type: DT_FLOAT + type: DT_INT64 + } + } + } + attr { + key: "timeout_ms" + value { + i: -1 + } + } +} +node { + name: "Reshape" + op: "Reshape" + input: "mnist_dataset_train_2/random_shuffle_queue_DequeueMany" + input: "Reshape/shape" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: -1 + } + } + } + } + } +} +node { + name: "MatMul" + op: "MatMul" + input: "Reshape" + input: "hidden_weights/read" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 100 + } + } + } + } + } + attr { + key: "transpose_a" + value { + b: false + } + } + attr { + key: "transpose_b" + value { + b: false + } + } +} +node { + name: "add" + op: "Add" + input: "MatMul" + input: "hidden_biases/read" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 100 + } + } + } + } + } +} +node { + name: "Relu" + op: "Relu" + input: "add" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 100 + } + } + } + } + } +} +node { + name: "MatMul_1" + op: "MatMul" + input: "Relu" + input: "logits_weights/read" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "transpose_a" + value { + b: false + } + } + attr { + key: "transpose_b" + value { + b: false + } + } +} +node { + name: "add_1" + op: "Add" + input: "MatMul_1" + input: "logits_biases/read" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 10 + } + } + } + } + } +} +node { + name: "Reshape_1" + op: "Reshape" + input: "add_1" + input: "concat" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 10 + } + } + } + } + } +} +node { + name: "mnist_dataset_train_2/one_hot" + op: "OneHot" + input: "mnist_dataset_train_2/random_shuffle_queue_DequeueMany:1" + input: "mnist_dataset_train_2/one_hot/depth" + input: "mnist_dataset_train_2/one_hot/on_value" + input: "mnist_dataset_train_2/one_hot/off_value" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "TI" + value { + type: DT_INT64 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + unknown_rank: true + } + } + } + } + attr { + key: "axis" + value { + i: -1 + } + } +} +node { + name: "Reshape_2" + op: "Reshape" + input: "mnist_dataset_train_2/one_hot" + input: "concat_1" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 10 + } + } + } + } + } +} +node { + name: "SoftmaxCrossEntropyWithLogits" + op: "SoftmaxCrossEntropyWithLogits" + input: "Reshape_1" + input: "Reshape_2" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + } + shape { + dim { + size: 200 + } + dim { + size: 10 + } + } + } + } + } +} +node { + name: "gradients/SoftmaxCrossEntropyWithLogits_grad/PreventGradient" + op: "PreventGradient" + input: "SoftmaxCrossEntropyWithLogits:1" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "message" + value { + s: "Currently there is no way to take the second derivative of softmax_cross_entropy_with_logits due to the fused implementation\'s interaction with tf.gradients()" + } + } +} +node { + name: "gradients/SoftmaxCrossEntropyWithLogits_grad/mul" + op: "Mul" + input: "gradients/SoftmaxCrossEntropyWithLogits_grad/ExpandDims" + input: "gradients/SoftmaxCrossEntropyWithLogits_grad/PreventGradient" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 10 + } + } + } + } + } +} +node { + name: "gradients/Reshape_1_grad/Reshape" + op: "Reshape" + input: "gradients/SoftmaxCrossEntropyWithLogits_grad/mul" + input: "gradients/Reshape_1_grad/Shape" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 10 + } + } + } + } + } +} +node { + name: "gradients/add_1_grad/Sum_1" + op: "Sum" + input: "gradients/Reshape_1_grad/Reshape" + input: "gradients/add_1_grad/BroadcastGradientArgs:1" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } +} +node { + name: "gradients/add_1_grad/Reshape_1" + op: "Reshape" + input: "gradients/add_1_grad/Sum_1" + input: "gradients/add_1_grad/Shape_1" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } +} +node { + name: "gradients/add_1_grad/Sum" + op: "Sum" + input: "gradients/Reshape_1_grad/Reshape" + input: "gradients/add_1_grad/BroadcastGradientArgs" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } +} +node { + name: "gradients/add_1_grad/Reshape" + op: "Reshape" + input: "gradients/add_1_grad/Sum" + input: "gradients/add_1_grad/Shape" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 10 + } + } + } + } + } +} +node { + name: "gradients/add_1_grad/tuple/group_deps" + op: "NoOp" + input: "^gradients/add_1_grad/Reshape" + input: "^gradients/add_1_grad/Reshape_1" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + } + } + } +} +node { + name: "gradients/add_1_grad/tuple/control_dependency_1" + op: "Identity" + input: "gradients/add_1_grad/Reshape_1" + input: "^gradients/add_1_grad/tuple/group_deps" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_class" + value { + list { + s: "loc:@gradients/add_1_grad/Reshape_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } +} +node { + name: "GradientDescent/update_logits_biases/ApplyGradientDescent" + op: "ApplyGradientDescent" + input: "logits_biases" + input: "GradientDescent/learning_rate" + input: "gradients/add_1_grad/tuple/control_dependency_1" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@logits_biases" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: false + } + } +} +node { + name: "gradients/add_1_grad/tuple/control_dependency" + op: "Identity" + input: "gradients/add_1_grad/Reshape" + input: "^gradients/add_1_grad/tuple/group_deps" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_class" + value { + list { + s: "loc:@gradients/add_1_grad/Reshape" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 10 + } + } + } + } + } +} +node { + name: "gradients/MatMul_1_grad/MatMul_1" + op: "MatMul" + input: "Relu" + input: "gradients/add_1_grad/tuple/control_dependency" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "transpose_a" + value { + b: true + } + } + attr { + key: "transpose_b" + value { + b: false + } + } +} +node { + name: "gradients/MatMul_1_grad/MatMul" + op: "MatMul" + input: "gradients/add_1_grad/tuple/control_dependency" + input: "logits_weights/read" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 100 + } + } + } + } + } + attr { + key: "transpose_a" + value { + b: false + } + } + attr { + key: "transpose_b" + value { + b: true + } + } +} +node { + name: "gradients/MatMul_1_grad/tuple/group_deps" + op: "NoOp" + input: "^gradients/MatMul_1_grad/MatMul" + input: "^gradients/MatMul_1_grad/MatMul_1" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + } + } + } +} +node { + name: "gradients/MatMul_1_grad/tuple/control_dependency_1" + op: "Identity" + input: "gradients/MatMul_1_grad/MatMul_1" + input: "^gradients/MatMul_1_grad/tuple/group_deps" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_class" + value { + list { + s: "loc:@gradients/MatMul_1_grad/MatMul_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + dim { + size: 10 + } + } + } + } + } +} +node { + name: "GradientDescent/update_logits_weights/ApplyGradientDescent" + op: "ApplyGradientDescent" + input: "logits_weights" + input: "GradientDescent/learning_rate" + input: "gradients/MatMul_1_grad/tuple/control_dependency_1" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@logits_weights" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: false + } + } +} +node { + name: "gradients/MatMul_1_grad/tuple/control_dependency" + op: "Identity" + input: "gradients/MatMul_1_grad/MatMul" + input: "^gradients/MatMul_1_grad/tuple/group_deps" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_class" + value { + list { + s: "loc:@gradients/MatMul_1_grad/MatMul" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 100 + } + } + } + } + } +} +node { + name: "gradients/Relu_grad/ReluGrad" + op: "ReluGrad" + input: "gradients/MatMul_1_grad/tuple/control_dependency" + input: "Relu" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 100 + } + } + } + } + } +} +node { + name: "gradients/add_grad/Sum_1" + op: "Sum" + input: "gradients/Relu_grad/ReluGrad" + input: "gradients/add_grad/BroadcastGradientArgs:1" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } +} +node { + name: "gradients/add_grad/Reshape_1" + op: "Reshape" + input: "gradients/add_grad/Sum_1" + input: "gradients/add_grad/Shape_1" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + } + } + } + } +} +node { + name: "gradients/add_grad/Sum" + op: "Sum" + input: "gradients/Relu_grad/ReluGrad" + input: "gradients/add_grad/BroadcastGradientArgs" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 100 + } + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } +} +node { + name: "gradients/add_grad/Reshape" + op: "Reshape" + input: "gradients/add_grad/Sum" + input: "gradients/add_grad/Shape" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 100 + } + } + } + } + } +} +node { + name: "gradients/add_grad/tuple/group_deps" + op: "NoOp" + input: "^gradients/add_grad/Reshape" + input: "^gradients/add_grad/Reshape_1" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + } + } + } +} +node { + name: "gradients/add_grad/tuple/control_dependency_1" + op: "Identity" + input: "gradients/add_grad/Reshape_1" + input: "^gradients/add_grad/tuple/group_deps" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_class" + value { + list { + s: "loc:@gradients/add_grad/Reshape_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + } + } + } + } +} +node { + name: "GradientDescent/update_hidden_biases/ApplyGradientDescent" + op: "ApplyGradientDescent" + input: "hidden_biases" + input: "GradientDescent/learning_rate" + input: "gradients/add_grad/tuple/control_dependency_1" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@hidden_biases" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: false + } + } +} +node { + name: "gradients/add_grad/tuple/control_dependency" + op: "Identity" + input: "gradients/add_grad/Reshape" + input: "^gradients/add_grad/tuple/group_deps" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_class" + value { + list { + s: "loc:@gradients/add_grad/Reshape" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 100 + } + } + } + } + } +} +node { + name: "gradients/MatMul_grad/MatMul_1" + op: "MatMul" + input: "Reshape" + input: "gradients/add_grad/tuple/control_dependency" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 100 + } + } + } + } + } + attr { + key: "transpose_a" + value { + b: true + } + } + attr { + key: "transpose_b" + value { + b: false + } + } +} +node { + name: "gradients/MatMul_grad/MatMul" + op: "MatMul" + input: "gradients/add_grad/tuple/control_dependency" + input: "hidden_weights/read" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 784 + } + } + } + } + } + attr { + key: "transpose_a" + value { + b: false + } + } + attr { + key: "transpose_b" + value { + b: true + } + } +} +node { + name: "gradients/MatMul_grad/tuple/group_deps" + op: "NoOp" + input: "^gradients/MatMul_grad/MatMul" + input: "^gradients/MatMul_grad/MatMul_1" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + } + } + } +} +node { + name: "gradients/MatMul_grad/tuple/control_dependency_1" + op: "Identity" + input: "gradients/MatMul_grad/MatMul_1" + input: "^gradients/MatMul_grad/tuple/group_deps" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_class" + value { + list { + s: "loc:@gradients/MatMul_grad/MatMul_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 100 + } + } + } + } + } +} +node { + name: "GradientDescent/update_hidden_weights/ApplyGradientDescent" + op: "ApplyGradientDescent" + input: "hidden_weights" + input: "GradientDescent/learning_rate" + input: "gradients/MatMul_grad/tuple/control_dependency_1" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@hidden_weights" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 784 + } + dim { + size: 100 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: false + } + } +} +node { + name: "GradientDescent" + op: "NoOp" + input: "^GradientDescent/update_hidden_weights/ApplyGradientDescent" + input: "^GradientDescent/update_hidden_biases/ApplyGradientDescent" + input: "^GradientDescent/update_logits_weights/ApplyGradientDescent" + input: "^GradientDescent/update_logits_biases/ApplyGradientDescent" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_2" + } + } + attr { + key: "_output_shapes" + value { + list { + } + } + } +} +node { + name: "Reshape_3" + op: "Reshape" + input: "SoftmaxCrossEntropyWithLogits" + input: "Slice_2" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + } + } + } + } +} +node { + name: "Mean" + op: "Mean" + input: "Reshape_3" + input: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } +} +node { + name: "_send_Mean_0" + op: "_Send" + input: "Mean" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "client_terminated" + value { + b: true + } + } + attr { + key: "recv_device" + value { + s: "/job:localhost/replica:0/task:0/cpu:0" + } + } + attr { + key: "send_device" + value { + s: "/job:localhost/replica:0/task:0/cpu:0" + } + } + attr { + key: "send_device_incarnation" + value { + i: -5924635994370253548 + } + } + attr { + key: "tensor_name" + value { + s: "Mean:0" + } + } +} +library { +} +versions { + producer: 21 +} diff --git a/tensorflow/tensorboard/components/tf_graph_d3v4/demo/index.html b/tensorflow/tensorboard/components/tf_graph_d3v4/demo/index.html new file mode 100644 index 0000000000..52e2f0b934 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_graph_d3v4/demo/index.html @@ -0,0 +1,92 @@ + + + + + + + +TF Graph Demo + + + + diff --git a/tensorflow/tensorboard/components/tf_graph_d3v4/tf-graph-minimap.html b/tensorflow/tensorboard/components/tf_graph_d3v4/tf-graph-minimap.html new file mode 100644 index 0000000000..5fc16c0520 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_graph_d3v4/tf-graph-minimap.html @@ -0,0 +1,88 @@ + + + + + + + + + diff --git a/tensorflow/tensorboard/components/tf_graph_d3v4/tf-graph-scene.html b/tensorflow/tensorboard/components/tf_graph_d3v4/tf-graph-scene.html new file mode 100644 index 0000000000..95d9d16f85 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_graph_d3v4/tf-graph-scene.html @@ -0,0 +1,1052 @@ + + + + + + + + + + + diff --git a/tensorflow/tensorboard/components/tf_graph_d3v4/tf-graph.html b/tensorflow/tensorboard/components/tf_graph_d3v4/tf-graph.html new file mode 100644 index 0000000000..efbf065a40 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_graph_d3v4/tf-graph.html @@ -0,0 +1,316 @@ + + + + + + + + + + + + + + + + diff --git a/tensorflow/tensorboard/components/tf_graph_dashboard_d3v4/BUILD b/tensorflow/tensorboard/components/tf_graph_dashboard_d3v4/BUILD new file mode 100644 index 0000000000..0cee324f48 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_graph_dashboard_d3v4/BUILD @@ -0,0 +1,30 @@ +package(default_visibility = ["//tensorflow:internal"]) + +load("@io_bazel_rules_closure//closure:defs.bzl", "web_library") +load("//tensorflow/tensorboard:defs.bzl", "tensorboard_ts_library") +load("//tensorflow/tensorboard:defs.bzl", "tensorboard_webcomponent_library") + +licenses(["notice"]) # Apache 2.0 + +web_library( + name = "tf_graph_dashboard_d3v4", + srcs = [ + "tf-graph-dashboard.html", + ], + path = "/tf-graph-dashboard", + deps = [ + "//tensorflow/tensorboard/components/tf_backend_d3v4", + "//tensorflow/tensorboard/components/tf_dashboard_common_d3v4", + "//tensorflow/tensorboard/components/tf_graph_board_d3v4", + "//tensorflow/tensorboard/components/tf_graph_controls_d3v4", + "//tensorflow/tensorboard/components/tf_graph_d3v4", + "//tensorflow/tensorboard/components/tf_graph_loader_d3v4", + "@org_polymer", + ], +) + +filegroup( + name = "all_files", + srcs = glob(["**"]), + tags = ["notsan"], +) diff --git a/tensorflow/tensorboard/components/tf_graph_dashboard_d3v4/demo/BUILD b/tensorflow/tensorboard/components/tf_graph_dashboard_d3v4/demo/BUILD new file mode 100644 index 0000000000..74238d78e2 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_graph_dashboard_d3v4/demo/BUILD @@ -0,0 +1,24 @@ +package(default_visibility = ["//tensorflow:internal"]) + +load("@io_bazel_rules_closure//closure:defs.bzl", "web_library") + +licenses(["notice"]) # Apache 2.0 + +# bazel run //third_party/tensorflow/tensorboard/components/tf_graph_dashboard/demo +web_library( + name = "demo", + srcs = ["index.html"] + glob(["data/**"]), + path = "/tf-graph-dashboard/demo", + deps = [ + "//tensorflow/tensorboard/components/tf_graph_dashboard_d3v4", + "@org_polymer_iron_demo_helpers", + "@org_polymer_paper_styles", + "@org_polymer_webcomponentsjs", + ], +) + +filegroup( + name = "all_files", + srcs = glob(["**"]), + tags = ["notsan"], +) diff --git a/tensorflow/tensorboard/components/tf_graph_dashboard_d3v4/demo/data/graph_run_run1.pbtxt b/tensorflow/tensorboard/components/tf_graph_dashboard_d3v4/demo/data/graph_run_run1.pbtxt new file mode 100644 index 0000000000..30b2064534 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_graph_dashboard_d3v4/demo/data/graph_run_run1.pbtxt @@ -0,0 +1,4606 @@ +node { + name: "GradientDescent/learning_rate" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_3" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.1 + } + } + } +} +node { + name: "gradients/add_grad/Shape_1" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 100 + } + } + } +} +node { + name: "gradients/add_grad/Shape" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\310\000\000\000d\000\000\000" + } + } + } +} +node { + name: "gradients/add_grad/BroadcastGradientArgs" + op: "BroadcastGradientArgs" + input: "gradients/add_grad/Shape" + input: "gradients/add_grad/Shape_1" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + shape { + dim { + size: -1 + } + } + } + } + } +} +node { + name: "gradients/add_1_grad/Shape_1" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 10 + } + } + } +} +node { + name: "gradients/add_1_grad/Shape" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\310\000\000\000\n\000\000\000" + } + } + } +} +node { + name: "gradients/add_1_grad/BroadcastGradientArgs" + op: "BroadcastGradientArgs" + input: "gradients/add_1_grad/Shape" + input: "gradients/add_1_grad/Shape_1" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + shape { + dim { + size: -1 + } + } + } + } + } +} +node { + name: "gradients/Reshape_1_grad/Shape" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\310\000\000\000\n\000\000\000" + } + } + } +} +node { + name: "gradients/SoftmaxCrossEntropyWithLogits_grad/ExpandDims/dim" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: -1 + } + } + } +} +node { + name: "gradients/Reshape_3_grad/Shape" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 200 + } + } + } +} +node { + name: "gradients/Mean_grad/Maximum/y" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } +} +node { + name: "gradients/Mean_grad/Const_1" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } +} +node { + name: "gradients/Mean_grad/Const" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } +} +node { + name: "gradients/Mean_grad/Shape_1" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + } + } + } + } + } +} +node { + name: "gradients/Mean_grad/Prod_1" + op: "Prod" + input: "gradients/Mean_grad/Shape_1" + input: "gradients/Mean_grad/Const_1" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } +} +node { + name: "gradients/Mean_grad/Maximum" + op: "Maximum" + input: "gradients/Mean_grad/Prod_1" + input: "gradients/Mean_grad/Maximum/y" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } +} +node { + name: "gradients/Mean_grad/Shape" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 200 + } + } + } +} +node { + name: "gradients/Mean_grad/Prod" + op: "Prod" + input: "gradients/Mean_grad/Shape" + input: "gradients/Mean_grad/Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } +} +node { + name: "gradients/Mean_grad/floordiv" + op: "FloorDiv" + input: "gradients/Mean_grad/Prod" + input: "gradients/Mean_grad/Maximum" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } +} +node { + name: "gradients/Mean_grad/Cast" + op: "Cast" + input: "gradients/Mean_grad/floordiv" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "DstT" + value { + type: DT_FLOAT + } + } + attr { + key: "SrcT" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } +} +node { + name: "gradients/Mean_grad/Tile/multiples" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 200 + } + } + } +} +node { + name: "gradients/Mean_grad/Reshape/shape" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } +} +node { + name: "gradients/Const" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1 + } + } + } +} +node { + name: "gradients/Shape" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + } + } + } + } + } +} +node { + name: "gradients/Fill" + op: "Fill" + input: "gradients/Shape" + input: "gradients/Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } +} +node { + name: "gradients/Mean_grad/Reshape" + op: "Reshape" + input: "gradients/Fill" + input: "gradients/Mean_grad/Reshape/shape" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } +} +node { + name: "gradients/Mean_grad/Tile" + op: "Tile" + input: "gradients/Mean_grad/Reshape" + input: "gradients/Mean_grad/Tile/multiples" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tmultiples" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + } + } + } + } +} +node { + name: "gradients/Mean_grad/truediv" + op: "RealDiv" + input: "gradients/Mean_grad/Tile" + input: "gradients/Mean_grad/Cast" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + } + } + } + } +} +node { + name: "gradients/Reshape_3_grad/Reshape" + op: "Reshape" + input: "gradients/Mean_grad/truediv" + input: "gradients/Reshape_3_grad/Shape" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + } + } + } + } +} +node { + name: "gradients/SoftmaxCrossEntropyWithLogits_grad/ExpandDims" + op: "ExpandDims" + input: "gradients/Reshape_3_grad/Reshape" + input: "gradients/SoftmaxCrossEntropyWithLogits_grad/ExpandDims/dim" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tdim" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 1 + } + } + } + } + } +} +node { + name: "Const" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } +} +node { + name: "Slice_2/begin" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } +} +node { + name: "Sub_2/y" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } +} +node { + name: "concat_1/axis" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 0 + } + } + } +} +node { + name: "concat_1/values_0" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: -1 + } + } + } +} +node { + name: "Slice_1/size" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } +} +node { + name: "Sub_1/y" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } +} +node { + name: "Shape_2" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\310\000\000\000\n\000\000\000" + } + } + } +} +node { + name: "Rank_2" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 2 + } + } + } +} +node { + name: "Sub_1" + op: "Sub" + input: "Rank_2" + input: "Sub_1/y" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } +} +node { + name: "Slice_1/begin" + op: "Pack" + input: "Sub_1" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "N" + value { + i: 1 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "axis" + value { + i: 0 + } + } +} +node { + name: "Slice_1" + op: "Slice" + input: "Shape_2" + input: "Slice_1/begin" + input: "Slice_1/size" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } +} +node { + name: "concat_1" + op: "ConcatV2" + input: "concat_1/values_0" + input: "Slice_1" + input: "concat_1/axis" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } +} +node { + name: "concat/axis" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 0 + } + } + } +} +node { + name: "concat/values_0" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: -1 + } + } + } +} +node { + name: "Slice/size" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } +} +node { + name: "Sub/y" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } +} +node { + name: "Shape_1" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\310\000\000\000\n\000\000\000" + } + } + } +} +node { + name: "Rank_1" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 2 + } + } + } +} +node { + name: "Sub" + op: "Sub" + input: "Rank_1" + input: "Sub/y" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } +} +node { + name: "Slice/begin" + op: "Pack" + input: "Sub" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "N" + value { + i: 1 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "axis" + value { + i: 0 + } + } +} +node { + name: "Slice" + op: "Slice" + input: "Shape_1" + input: "Slice/begin" + input: "Slice/size" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } +} +node { + name: "concat" + op: "ConcatV2" + input: "concat/values_0" + input: "Slice" + input: "concat/axis" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } +} +node { + name: "Shape" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\310\000\000\000\n\000\000\000" + } + } + } +} +node { + name: "Rank" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 2 + } + } + } +} +node { + name: "Sub_2" + op: "Sub" + input: "Rank" + input: "Sub_2/y" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } +} +node { + name: "Slice_2/size" + op: "Pack" + input: "Sub_2" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "N" + value { + i: 1 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "axis" + value { + i: 0 + } + } +} +node { + name: "Slice_2" + op: "Slice" + input: "Shape" + input: "Slice_2/begin" + input: "Slice_2/size" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } +} +node { + name: "logits_biases" + op: "VariableV2" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "_class" + value { + list { + s: "loc:@logits_biases" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 10 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "logits_biases/read" + op: "Identity" + input: "logits_biases" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@logits_biases" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } +} +node { + name: "logits_weights" + op: "VariableV2" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "_class" + value { + list { + s: "loc:@logits_weights" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 100 + } + dim { + size: 10 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "logits_weights/read" + op: "Identity" + input: "logits_weights" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@logits_weights" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + dim { + size: 10 + } + } + } + } + } +} +node { + name: "hidden_biases" + op: "VariableV2" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "_class" + value { + list { + s: "loc:@hidden_biases" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + } + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 100 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "hidden_biases/read" + op: "Identity" + input: "hidden_biases" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@hidden_biases" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + } + } + } + } +} +node { + name: "hidden_weights" + op: "VariableV2" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "_class" + value { + list { + s: "loc:@hidden_weights" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 784 + } + dim { + size: 100 + } + } + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 784 + } + dim { + size: 100 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "hidden_weights/read" + op: "Identity" + input: "hidden_weights" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@hidden_weights" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 784 + } + dim { + size: 100 + } + } + } + } + } +} +node { + name: "Reshape/shape" + op: "Const" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\310\000\000\000\377\377\377\377" + } + } + } +} +node { + name: "mnist_dataset_train_2/one_hot/depth" + op: "Const" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 10 + } + } + } +} +node { + name: "mnist_dataset_train_2/one_hot/off_value" + op: "Const" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0 + } + } + } +} +node { + name: "mnist_dataset_train_2/one_hot/on_value" + op: "Const" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1 + } + } + } +} +node { + name: "mnist_dataset_train_2/random_shuffle_queue_DequeueMany/n" + op: "Const" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 200 + } + } + } +} +node { + name: "mnist_dataset_train_1/random_shuffle_queue" + op: "RandomShuffleQueueV2" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "capacity" + value { + i: 20000 + } + } + attr { + key: "component_types" + value { + list { + type: DT_FLOAT + type: DT_INT64 + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "min_after_dequeue" + value { + i: 4000 + } + } + attr { + key: "seed" + value { + i: 0 + } + } + attr { + key: "seed2" + value { + i: 0 + } + } + attr { + key: "shapes" + value { + list { + shape { + dim { + size: 28 + } + dim { + size: 28 + } + dim { + size: 1 + } + } + shape { + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "mnist_dataset_train_2/random_shuffle_queue_DequeueMany" + op: "QueueDequeueManyV2" + input: "mnist_dataset_train_1/random_shuffle_queue" + input: "mnist_dataset_train_2/random_shuffle_queue_DequeueMany/n" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "_output_shapes" + value { + list { + shape { + unknown_rank: true + } + shape { + unknown_rank: true + } + } + } + } + attr { + key: "component_types" + value { + list { + type: DT_FLOAT + type: DT_INT64 + } + } + } + attr { + key: "timeout_ms" + value { + i: -1 + } + } +} +node { + name: "Reshape" + op: "Reshape" + input: "mnist_dataset_train_2/random_shuffle_queue_DequeueMany" + input: "Reshape/shape" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: -1 + } + } + } + } + } +} +node { + name: "MatMul" + op: "MatMul" + input: "Reshape" + input: "hidden_weights/read" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 100 + } + } + } + } + } + attr { + key: "transpose_a" + value { + b: false + } + } + attr { + key: "transpose_b" + value { + b: false + } + } +} +node { + name: "add" + op: "Add" + input: "MatMul" + input: "hidden_biases/read" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 100 + } + } + } + } + } +} +node { + name: "Relu" + op: "Relu" + input: "add" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 100 + } + } + } + } + } +} +node { + name: "MatMul_1" + op: "MatMul" + input: "Relu" + input: "logits_weights/read" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "transpose_a" + value { + b: false + } + } + attr { + key: "transpose_b" + value { + b: false + } + } +} +node { + name: "add_1" + op: "Add" + input: "MatMul_1" + input: "logits_biases/read" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 10 + } + } + } + } + } +} +node { + name: "Reshape_1" + op: "Reshape" + input: "add_1" + input: "concat" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 10 + } + } + } + } + } +} +node { + name: "mnist_dataset_train_2/one_hot" + op: "OneHot" + input: "mnist_dataset_train_2/random_shuffle_queue_DequeueMany:1" + input: "mnist_dataset_train_2/one_hot/depth" + input: "mnist_dataset_train_2/one_hot/on_value" + input: "mnist_dataset_train_2/one_hot/off_value" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "TI" + value { + type: DT_INT64 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + unknown_rank: true + } + } + } + } + attr { + key: "axis" + value { + i: -1 + } + } +} +node { + name: "Reshape_2" + op: "Reshape" + input: "mnist_dataset_train_2/one_hot" + input: "concat_1" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 10 + } + } + } + } + } +} +node { + name: "SoftmaxCrossEntropyWithLogits" + op: "SoftmaxCrossEntropyWithLogits" + input: "Reshape_1" + input: "Reshape_2" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + } + shape { + dim { + size: 200 + } + dim { + size: 10 + } + } + } + } + } +} +node { + name: "gradients/SoftmaxCrossEntropyWithLogits_grad/PreventGradient" + op: "PreventGradient" + input: "SoftmaxCrossEntropyWithLogits:1" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "message" + value { + s: "Currently there is no way to take the second derivative of softmax_cross_entropy_with_logits due to the fused implementation\'s interaction with tf.gradients()" + } + } +} +node { + name: "gradients/SoftmaxCrossEntropyWithLogits_grad/mul" + op: "Mul" + input: "gradients/SoftmaxCrossEntropyWithLogits_grad/ExpandDims" + input: "gradients/SoftmaxCrossEntropyWithLogits_grad/PreventGradient" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 10 + } + } + } + } + } +} +node { + name: "gradients/Reshape_1_grad/Reshape" + op: "Reshape" + input: "gradients/SoftmaxCrossEntropyWithLogits_grad/mul" + input: "gradients/Reshape_1_grad/Shape" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 10 + } + } + } + } + } +} +node { + name: "gradients/add_1_grad/Sum_1" + op: "Sum" + input: "gradients/Reshape_1_grad/Reshape" + input: "gradients/add_1_grad/BroadcastGradientArgs:1" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } +} +node { + name: "gradients/add_1_grad/Reshape_1" + op: "Reshape" + input: "gradients/add_1_grad/Sum_1" + input: "gradients/add_1_grad/Shape_1" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } +} +node { + name: "gradients/add_1_grad/Sum" + op: "Sum" + input: "gradients/Reshape_1_grad/Reshape" + input: "gradients/add_1_grad/BroadcastGradientArgs" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } +} +node { + name: "gradients/add_1_grad/Reshape" + op: "Reshape" + input: "gradients/add_1_grad/Sum" + input: "gradients/add_1_grad/Shape" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 10 + } + } + } + } + } +} +node { + name: "gradients/add_1_grad/tuple/group_deps" + op: "NoOp" + input: "^gradients/add_1_grad/Reshape" + input: "^gradients/add_1_grad/Reshape_1" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + } + } + } +} +node { + name: "gradients/add_1_grad/tuple/control_dependency_1" + op: "Identity" + input: "gradients/add_1_grad/Reshape_1" + input: "^gradients/add_1_grad/tuple/group_deps" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_class" + value { + list { + s: "loc:@gradients/add_1_grad/Reshape_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } +} +node { + name: "GradientDescent/update_logits_biases/ApplyGradientDescent" + op: "ApplyGradientDescent" + input: "logits_biases" + input: "GradientDescent/learning_rate" + input: "gradients/add_1_grad/tuple/control_dependency_1" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@logits_biases" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: false + } + } +} +node { + name: "gradients/add_1_grad/tuple/control_dependency" + op: "Identity" + input: "gradients/add_1_grad/Reshape" + input: "^gradients/add_1_grad/tuple/group_deps" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_class" + value { + list { + s: "loc:@gradients/add_1_grad/Reshape" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 10 + } + } + } + } + } +} +node { + name: "gradients/MatMul_1_grad/MatMul_1" + op: "MatMul" + input: "Relu" + input: "gradients/add_1_grad/tuple/control_dependency" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "transpose_a" + value { + b: true + } + } + attr { + key: "transpose_b" + value { + b: false + } + } +} +node { + name: "gradients/MatMul_1_grad/MatMul" + op: "MatMul" + input: "gradients/add_1_grad/tuple/control_dependency" + input: "logits_weights/read" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 100 + } + } + } + } + } + attr { + key: "transpose_a" + value { + b: false + } + } + attr { + key: "transpose_b" + value { + b: true + } + } +} +node { + name: "gradients/MatMul_1_grad/tuple/group_deps" + op: "NoOp" + input: "^gradients/MatMul_1_grad/MatMul" + input: "^gradients/MatMul_1_grad/MatMul_1" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + } + } + } +} +node { + name: "gradients/MatMul_1_grad/tuple/control_dependency_1" + op: "Identity" + input: "gradients/MatMul_1_grad/MatMul_1" + input: "^gradients/MatMul_1_grad/tuple/group_deps" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_class" + value { + list { + s: "loc:@gradients/MatMul_1_grad/MatMul_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + dim { + size: 10 + } + } + } + } + } +} +node { + name: "GradientDescent/update_logits_weights/ApplyGradientDescent" + op: "ApplyGradientDescent" + input: "logits_weights" + input: "GradientDescent/learning_rate" + input: "gradients/MatMul_1_grad/tuple/control_dependency_1" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@logits_weights" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: false + } + } +} +node { + name: "gradients/MatMul_1_grad/tuple/control_dependency" + op: "Identity" + input: "gradients/MatMul_1_grad/MatMul" + input: "^gradients/MatMul_1_grad/tuple/group_deps" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_class" + value { + list { + s: "loc:@gradients/MatMul_1_grad/MatMul" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 100 + } + } + } + } + } +} +node { + name: "gradients/Relu_grad/ReluGrad" + op: "ReluGrad" + input: "gradients/MatMul_1_grad/tuple/control_dependency" + input: "Relu" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 100 + } + } + } + } + } +} +node { + name: "gradients/add_grad/Sum_1" + op: "Sum" + input: "gradients/Relu_grad/ReluGrad" + input: "gradients/add_grad/BroadcastGradientArgs:1" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } +} +node { + name: "gradients/add_grad/Reshape_1" + op: "Reshape" + input: "gradients/add_grad/Sum_1" + input: "gradients/add_grad/Shape_1" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + } + } + } + } +} +node { + name: "gradients/add_grad/Sum" + op: "Sum" + input: "gradients/Relu_grad/ReluGrad" + input: "gradients/add_grad/BroadcastGradientArgs" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 100 + } + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } +} +node { + name: "gradients/add_grad/Reshape" + op: "Reshape" + input: "gradients/add_grad/Sum" + input: "gradients/add_grad/Shape" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 100 + } + } + } + } + } +} +node { + name: "gradients/add_grad/tuple/group_deps" + op: "NoOp" + input: "^gradients/add_grad/Reshape" + input: "^gradients/add_grad/Reshape_1" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + } + } + } +} +node { + name: "gradients/add_grad/tuple/control_dependency_1" + op: "Identity" + input: "gradients/add_grad/Reshape_1" + input: "^gradients/add_grad/tuple/group_deps" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_class" + value { + list { + s: "loc:@gradients/add_grad/Reshape_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + } + } + } + } +} +node { + name: "GradientDescent/update_hidden_biases/ApplyGradientDescent" + op: "ApplyGradientDescent" + input: "hidden_biases" + input: "GradientDescent/learning_rate" + input: "gradients/add_grad/tuple/control_dependency_1" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@hidden_biases" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: false + } + } +} +node { + name: "gradients/add_grad/tuple/control_dependency" + op: "Identity" + input: "gradients/add_grad/Reshape" + input: "^gradients/add_grad/tuple/group_deps" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_class" + value { + list { + s: "loc:@gradients/add_grad/Reshape" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 100 + } + } + } + } + } +} +node { + name: "gradients/MatMul_grad/MatMul_1" + op: "MatMul" + input: "Reshape" + input: "gradients/add_grad/tuple/control_dependency" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 100 + } + } + } + } + } + attr { + key: "transpose_a" + value { + b: true + } + } + attr { + key: "transpose_b" + value { + b: false + } + } +} +node { + name: "gradients/MatMul_grad/MatMul" + op: "MatMul" + input: "gradients/add_grad/tuple/control_dependency" + input: "hidden_weights/read" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 784 + } + } + } + } + } + attr { + key: "transpose_a" + value { + b: false + } + } + attr { + key: "transpose_b" + value { + b: true + } + } +} +node { + name: "gradients/MatMul_grad/tuple/group_deps" + op: "NoOp" + input: "^gradients/MatMul_grad/MatMul" + input: "^gradients/MatMul_grad/MatMul_1" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + } + } + } +} +node { + name: "gradients/MatMul_grad/tuple/control_dependency_1" + op: "Identity" + input: "gradients/MatMul_grad/MatMul_1" + input: "^gradients/MatMul_grad/tuple/group_deps" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_class" + value { + list { + s: "loc:@gradients/MatMul_grad/MatMul_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 100 + } + } + } + } + } +} +node { + name: "GradientDescent/update_hidden_weights/ApplyGradientDescent" + op: "ApplyGradientDescent" + input: "hidden_weights" + input: "GradientDescent/learning_rate" + input: "gradients/MatMul_grad/tuple/control_dependency_1" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@hidden_weights" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 784 + } + dim { + size: 100 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: false + } + } +} +node { + name: "GradientDescent" + op: "NoOp" + input: "^GradientDescent/update_hidden_weights/ApplyGradientDescent" + input: "^GradientDescent/update_hidden_biases/ApplyGradientDescent" + input: "^GradientDescent/update_logits_weights/ApplyGradientDescent" + input: "^GradientDescent/update_logits_biases/ApplyGradientDescent" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_2" + } + } + attr { + key: "_output_shapes" + value { + list { + } + } + } +} +node { + name: "Reshape_3" + op: "Reshape" + input: "SoftmaxCrossEntropyWithLogits" + input: "Slice_2" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + } + } + } + } +} +node { + name: "Mean" + op: "Mean" + input: "Reshape_3" + input: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } +} +node { + name: "_send_Mean_0" + op: "_Send" + input: "Mean" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "client_terminated" + value { + b: true + } + } + attr { + key: "recv_device" + value { + s: "/job:localhost/replica:0/task:0/cpu:0" + } + } + attr { + key: "send_device" + value { + s: "/job:localhost/replica:0/task:0/cpu:0" + } + } + attr { + key: "send_device_incarnation" + value { + i: -5924635994370253548 + } + } + attr { + key: "tensor_name" + value { + s: "Mean:0" + } + } +} +library { +} +versions { + producer: 21 +} diff --git a/tensorflow/tensorboard/components/tf_graph_dashboard_d3v4/demo/data/runs.json b/tensorflow/tensorboard/components/tf_graph_dashboard_d3v4/demo/data/runs.json new file mode 100644 index 0000000000..0429aa71f8 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_graph_dashboard_d3v4/demo/data/runs.json @@ -0,0 +1,6 @@ +{ + "run1": { + "graph": true, + "scalars": ["foo/sin"] + } +} diff --git a/tensorflow/tensorboard/components/tf_graph_dashboard_d3v4/demo/index.html b/tensorflow/tensorboard/components/tf_graph_dashboard_d3v4/demo/index.html new file mode 100644 index 0000000000..67756cc129 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_graph_dashboard_d3v4/demo/index.html @@ -0,0 +1,56 @@ + + + + + + + + +Graph Dashboard Demo + + + + diff --git a/tensorflow/tensorboard/components/tf_graph_dashboard_d3v4/tf-graph-dashboard.html b/tensorflow/tensorboard/components/tf_graph_dashboard_d3v4/tf-graph-dashboard.html new file mode 100644 index 0000000000..bfc52a0a44 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_graph_dashboard_d3v4/tf-graph-dashboard.html @@ -0,0 +1,304 @@ + + + + + + + + + + + + + + + diff --git a/tensorflow/tensorboard/components/tf_graph_info_d3v4/BUILD b/tensorflow/tensorboard/components/tf_graph_info_d3v4/BUILD new file mode 100644 index 0000000000..f84726e7c7 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_graph_info_d3v4/BUILD @@ -0,0 +1,35 @@ +package(default_visibility = ["//tensorflow:internal"]) + +load("@io_bazel_rules_closure//closure:defs.bzl", "web_library") +load("//tensorflow/tensorboard:defs.bzl", "tensorboard_ts_library") +load("//tensorflow/tensorboard:defs.bzl", "tensorboard_webcomponent_library") + +licenses(["notice"]) # Apache 2.0 + +web_library( + name = "tf_graph_info_d3v4", + srcs = [ + "tf-graph-icon.html", + "tf-graph-info.html", + "tf-node-info.html", + "tf-node-list-item.html", + ], + path = "/tf-graph-info", + deps = [ + "//tensorflow/tensorboard/components/tf_dashboard_common_d3v4", + "//tensorflow/tensorboard/components/tf_graph_common_d3v4", + "@org_polymer", + "@org_polymer_iron_collapse", + "@org_polymer_iron_list", + "@org_polymer_paper_icon_button", + "@org_polymer_paper_item", + "@org_polymer_paper_slider", + "@org_polymer_paper_spinner", + ], +) + +filegroup( + name = "all_files", + srcs = glob(["**"]), + tags = ["notsan"], +) diff --git a/tensorflow/tensorboard/components/tf_graph_info_d3v4/demo/BUILD b/tensorflow/tensorboard/components/tf_graph_info_d3v4/demo/BUILD new file mode 100644 index 0000000000..a7d59418fd --- /dev/null +++ b/tensorflow/tensorboard/components/tf_graph_info_d3v4/demo/BUILD @@ -0,0 +1,26 @@ +package(default_visibility = ["//tensorflow:internal"]) + +load("@io_bazel_rules_closure//closure:defs.bzl", "web_library") + +licenses(["notice"]) # Apache 2.0 + +# bazel run //third_party/tensorflow/tensorboard/components/tf_graph_info/demo +web_library( + name = "demo", + srcs = ["index.html"] + glob(["data/**"]), + path = "/tf-graph-info/demo", + deps = [ + "//tensorflow/tensorboard/components/tf_graph_common", + "//tensorflow/tensorboard/components/tf_graph_info", + "//tensorflow/tensorboard/components/tf_graph_loader", + "@org_polymer_iron_demo_helpers", + "@org_polymer_paper_styles", + "@org_polymer_webcomponentsjs", + ], +) + +filegroup( + name = "all_files", + srcs = glob(["**"]), + tags = ["notsan"], +) diff --git a/tensorflow/tensorboard/components/tf_graph_info_d3v4/demo/data/graph.pbtxt b/tensorflow/tensorboard/components/tf_graph_info_d3v4/demo/data/graph.pbtxt new file mode 100644 index 0000000000..30b2064534 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_graph_info_d3v4/demo/data/graph.pbtxt @@ -0,0 +1,4606 @@ +node { + name: "GradientDescent/learning_rate" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_3" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.1 + } + } + } +} +node { + name: "gradients/add_grad/Shape_1" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 100 + } + } + } +} +node { + name: "gradients/add_grad/Shape" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\310\000\000\000d\000\000\000" + } + } + } +} +node { + name: "gradients/add_grad/BroadcastGradientArgs" + op: "BroadcastGradientArgs" + input: "gradients/add_grad/Shape" + input: "gradients/add_grad/Shape_1" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + shape { + dim { + size: -1 + } + } + } + } + } +} +node { + name: "gradients/add_1_grad/Shape_1" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 10 + } + } + } +} +node { + name: "gradients/add_1_grad/Shape" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\310\000\000\000\n\000\000\000" + } + } + } +} +node { + name: "gradients/add_1_grad/BroadcastGradientArgs" + op: "BroadcastGradientArgs" + input: "gradients/add_1_grad/Shape" + input: "gradients/add_1_grad/Shape_1" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + shape { + dim { + size: -1 + } + } + } + } + } +} +node { + name: "gradients/Reshape_1_grad/Shape" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\310\000\000\000\n\000\000\000" + } + } + } +} +node { + name: "gradients/SoftmaxCrossEntropyWithLogits_grad/ExpandDims/dim" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: -1 + } + } + } +} +node { + name: "gradients/Reshape_3_grad/Shape" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 200 + } + } + } +} +node { + name: "gradients/Mean_grad/Maximum/y" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } +} +node { + name: "gradients/Mean_grad/Const_1" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } +} +node { + name: "gradients/Mean_grad/Const" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } +} +node { + name: "gradients/Mean_grad/Shape_1" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + } + } + } + } + } +} +node { + name: "gradients/Mean_grad/Prod_1" + op: "Prod" + input: "gradients/Mean_grad/Shape_1" + input: "gradients/Mean_grad/Const_1" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } +} +node { + name: "gradients/Mean_grad/Maximum" + op: "Maximum" + input: "gradients/Mean_grad/Prod_1" + input: "gradients/Mean_grad/Maximum/y" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } +} +node { + name: "gradients/Mean_grad/Shape" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 200 + } + } + } +} +node { + name: "gradients/Mean_grad/Prod" + op: "Prod" + input: "gradients/Mean_grad/Shape" + input: "gradients/Mean_grad/Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } +} +node { + name: "gradients/Mean_grad/floordiv" + op: "FloorDiv" + input: "gradients/Mean_grad/Prod" + input: "gradients/Mean_grad/Maximum" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } +} +node { + name: "gradients/Mean_grad/Cast" + op: "Cast" + input: "gradients/Mean_grad/floordiv" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "DstT" + value { + type: DT_FLOAT + } + } + attr { + key: "SrcT" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } +} +node { + name: "gradients/Mean_grad/Tile/multiples" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 200 + } + } + } +} +node { + name: "gradients/Mean_grad/Reshape/shape" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } +} +node { + name: "gradients/Const" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1 + } + } + } +} +node { + name: "gradients/Shape" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + } + } + } + } + } +} +node { + name: "gradients/Fill" + op: "Fill" + input: "gradients/Shape" + input: "gradients/Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } +} +node { + name: "gradients/Mean_grad/Reshape" + op: "Reshape" + input: "gradients/Fill" + input: "gradients/Mean_grad/Reshape/shape" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } +} +node { + name: "gradients/Mean_grad/Tile" + op: "Tile" + input: "gradients/Mean_grad/Reshape" + input: "gradients/Mean_grad/Tile/multiples" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tmultiples" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + } + } + } + } +} +node { + name: "gradients/Mean_grad/truediv" + op: "RealDiv" + input: "gradients/Mean_grad/Tile" + input: "gradients/Mean_grad/Cast" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + } + } + } + } +} +node { + name: "gradients/Reshape_3_grad/Reshape" + op: "Reshape" + input: "gradients/Mean_grad/truediv" + input: "gradients/Reshape_3_grad/Shape" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + } + } + } + } +} +node { + name: "gradients/SoftmaxCrossEntropyWithLogits_grad/ExpandDims" + op: "ExpandDims" + input: "gradients/Reshape_3_grad/Reshape" + input: "gradients/SoftmaxCrossEntropyWithLogits_grad/ExpandDims/dim" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tdim" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 1 + } + } + } + } + } +} +node { + name: "Const" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } +} +node { + name: "Slice_2/begin" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } +} +node { + name: "Sub_2/y" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } +} +node { + name: "concat_1/axis" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 0 + } + } + } +} +node { + name: "concat_1/values_0" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: -1 + } + } + } +} +node { + name: "Slice_1/size" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } +} +node { + name: "Sub_1/y" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } +} +node { + name: "Shape_2" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\310\000\000\000\n\000\000\000" + } + } + } +} +node { + name: "Rank_2" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 2 + } + } + } +} +node { + name: "Sub_1" + op: "Sub" + input: "Rank_2" + input: "Sub_1/y" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } +} +node { + name: "Slice_1/begin" + op: "Pack" + input: "Sub_1" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "N" + value { + i: 1 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "axis" + value { + i: 0 + } + } +} +node { + name: "Slice_1" + op: "Slice" + input: "Shape_2" + input: "Slice_1/begin" + input: "Slice_1/size" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } +} +node { + name: "concat_1" + op: "ConcatV2" + input: "concat_1/values_0" + input: "Slice_1" + input: "concat_1/axis" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } +} +node { + name: "concat/axis" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 0 + } + } + } +} +node { + name: "concat/values_0" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: -1 + } + } + } +} +node { + name: "Slice/size" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } +} +node { + name: "Sub/y" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } +} +node { + name: "Shape_1" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\310\000\000\000\n\000\000\000" + } + } + } +} +node { + name: "Rank_1" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 2 + } + } + } +} +node { + name: "Sub" + op: "Sub" + input: "Rank_1" + input: "Sub/y" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } +} +node { + name: "Slice/begin" + op: "Pack" + input: "Sub" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "N" + value { + i: 1 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "axis" + value { + i: 0 + } + } +} +node { + name: "Slice" + op: "Slice" + input: "Shape_1" + input: "Slice/begin" + input: "Slice/size" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } +} +node { + name: "concat" + op: "ConcatV2" + input: "concat/values_0" + input: "Slice" + input: "concat/axis" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } +} +node { + name: "Shape" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\310\000\000\000\n\000\000\000" + } + } + } +} +node { + name: "Rank" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 2 + } + } + } +} +node { + name: "Sub_2" + op: "Sub" + input: "Rank" + input: "Sub_2/y" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } +} +node { + name: "Slice_2/size" + op: "Pack" + input: "Sub_2" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "N" + value { + i: 1 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "axis" + value { + i: 0 + } + } +} +node { + name: "Slice_2" + op: "Slice" + input: "Shape" + input: "Slice_2/begin" + input: "Slice_2/size" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } +} +node { + name: "logits_biases" + op: "VariableV2" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "_class" + value { + list { + s: "loc:@logits_biases" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 10 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "logits_biases/read" + op: "Identity" + input: "logits_biases" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@logits_biases" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } +} +node { + name: "logits_weights" + op: "VariableV2" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "_class" + value { + list { + s: "loc:@logits_weights" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 100 + } + dim { + size: 10 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "logits_weights/read" + op: "Identity" + input: "logits_weights" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@logits_weights" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + dim { + size: 10 + } + } + } + } + } +} +node { + name: "hidden_biases" + op: "VariableV2" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "_class" + value { + list { + s: "loc:@hidden_biases" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + } + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 100 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "hidden_biases/read" + op: "Identity" + input: "hidden_biases" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@hidden_biases" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + } + } + } + } +} +node { + name: "hidden_weights" + op: "VariableV2" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "_class" + value { + list { + s: "loc:@hidden_weights" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 784 + } + dim { + size: 100 + } + } + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 784 + } + dim { + size: 100 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "hidden_weights/read" + op: "Identity" + input: "hidden_weights" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@hidden_weights" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 784 + } + dim { + size: 100 + } + } + } + } + } +} +node { + name: "Reshape/shape" + op: "Const" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\310\000\000\000\377\377\377\377" + } + } + } +} +node { + name: "mnist_dataset_train_2/one_hot/depth" + op: "Const" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 10 + } + } + } +} +node { + name: "mnist_dataset_train_2/one_hot/off_value" + op: "Const" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0 + } + } + } +} +node { + name: "mnist_dataset_train_2/one_hot/on_value" + op: "Const" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1 + } + } + } +} +node { + name: "mnist_dataset_train_2/random_shuffle_queue_DequeueMany/n" + op: "Const" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 200 + } + } + } +} +node { + name: "mnist_dataset_train_1/random_shuffle_queue" + op: "RandomShuffleQueueV2" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "capacity" + value { + i: 20000 + } + } + attr { + key: "component_types" + value { + list { + type: DT_FLOAT + type: DT_INT64 + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "min_after_dequeue" + value { + i: 4000 + } + } + attr { + key: "seed" + value { + i: 0 + } + } + attr { + key: "seed2" + value { + i: 0 + } + } + attr { + key: "shapes" + value { + list { + shape { + dim { + size: 28 + } + dim { + size: 28 + } + dim { + size: 1 + } + } + shape { + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "mnist_dataset_train_2/random_shuffle_queue_DequeueMany" + op: "QueueDequeueManyV2" + input: "mnist_dataset_train_1/random_shuffle_queue" + input: "mnist_dataset_train_2/random_shuffle_queue_DequeueMany/n" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "_output_shapes" + value { + list { + shape { + unknown_rank: true + } + shape { + unknown_rank: true + } + } + } + } + attr { + key: "component_types" + value { + list { + type: DT_FLOAT + type: DT_INT64 + } + } + } + attr { + key: "timeout_ms" + value { + i: -1 + } + } +} +node { + name: "Reshape" + op: "Reshape" + input: "mnist_dataset_train_2/random_shuffle_queue_DequeueMany" + input: "Reshape/shape" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: -1 + } + } + } + } + } +} +node { + name: "MatMul" + op: "MatMul" + input: "Reshape" + input: "hidden_weights/read" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 100 + } + } + } + } + } + attr { + key: "transpose_a" + value { + b: false + } + } + attr { + key: "transpose_b" + value { + b: false + } + } +} +node { + name: "add" + op: "Add" + input: "MatMul" + input: "hidden_biases/read" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 100 + } + } + } + } + } +} +node { + name: "Relu" + op: "Relu" + input: "add" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 100 + } + } + } + } + } +} +node { + name: "MatMul_1" + op: "MatMul" + input: "Relu" + input: "logits_weights/read" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "transpose_a" + value { + b: false + } + } + attr { + key: "transpose_b" + value { + b: false + } + } +} +node { + name: "add_1" + op: "Add" + input: "MatMul_1" + input: "logits_biases/read" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 10 + } + } + } + } + } +} +node { + name: "Reshape_1" + op: "Reshape" + input: "add_1" + input: "concat" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 10 + } + } + } + } + } +} +node { + name: "mnist_dataset_train_2/one_hot" + op: "OneHot" + input: "mnist_dataset_train_2/random_shuffle_queue_DequeueMany:1" + input: "mnist_dataset_train_2/one_hot/depth" + input: "mnist_dataset_train_2/one_hot/on_value" + input: "mnist_dataset_train_2/one_hot/off_value" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "TI" + value { + type: DT_INT64 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + unknown_rank: true + } + } + } + } + attr { + key: "axis" + value { + i: -1 + } + } +} +node { + name: "Reshape_2" + op: "Reshape" + input: "mnist_dataset_train_2/one_hot" + input: "concat_1" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 10 + } + } + } + } + } +} +node { + name: "SoftmaxCrossEntropyWithLogits" + op: "SoftmaxCrossEntropyWithLogits" + input: "Reshape_1" + input: "Reshape_2" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + } + shape { + dim { + size: 200 + } + dim { + size: 10 + } + } + } + } + } +} +node { + name: "gradients/SoftmaxCrossEntropyWithLogits_grad/PreventGradient" + op: "PreventGradient" + input: "SoftmaxCrossEntropyWithLogits:1" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "message" + value { + s: "Currently there is no way to take the second derivative of softmax_cross_entropy_with_logits due to the fused implementation\'s interaction with tf.gradients()" + } + } +} +node { + name: "gradients/SoftmaxCrossEntropyWithLogits_grad/mul" + op: "Mul" + input: "gradients/SoftmaxCrossEntropyWithLogits_grad/ExpandDims" + input: "gradients/SoftmaxCrossEntropyWithLogits_grad/PreventGradient" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 10 + } + } + } + } + } +} +node { + name: "gradients/Reshape_1_grad/Reshape" + op: "Reshape" + input: "gradients/SoftmaxCrossEntropyWithLogits_grad/mul" + input: "gradients/Reshape_1_grad/Shape" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 10 + } + } + } + } + } +} +node { + name: "gradients/add_1_grad/Sum_1" + op: "Sum" + input: "gradients/Reshape_1_grad/Reshape" + input: "gradients/add_1_grad/BroadcastGradientArgs:1" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } +} +node { + name: "gradients/add_1_grad/Reshape_1" + op: "Reshape" + input: "gradients/add_1_grad/Sum_1" + input: "gradients/add_1_grad/Shape_1" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } +} +node { + name: "gradients/add_1_grad/Sum" + op: "Sum" + input: "gradients/Reshape_1_grad/Reshape" + input: "gradients/add_1_grad/BroadcastGradientArgs" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } +} +node { + name: "gradients/add_1_grad/Reshape" + op: "Reshape" + input: "gradients/add_1_grad/Sum" + input: "gradients/add_1_grad/Shape" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 10 + } + } + } + } + } +} +node { + name: "gradients/add_1_grad/tuple/group_deps" + op: "NoOp" + input: "^gradients/add_1_grad/Reshape" + input: "^gradients/add_1_grad/Reshape_1" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + } + } + } +} +node { + name: "gradients/add_1_grad/tuple/control_dependency_1" + op: "Identity" + input: "gradients/add_1_grad/Reshape_1" + input: "^gradients/add_1_grad/tuple/group_deps" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_class" + value { + list { + s: "loc:@gradients/add_1_grad/Reshape_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } +} +node { + name: "GradientDescent/update_logits_biases/ApplyGradientDescent" + op: "ApplyGradientDescent" + input: "logits_biases" + input: "GradientDescent/learning_rate" + input: "gradients/add_1_grad/tuple/control_dependency_1" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@logits_biases" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: false + } + } +} +node { + name: "gradients/add_1_grad/tuple/control_dependency" + op: "Identity" + input: "gradients/add_1_grad/Reshape" + input: "^gradients/add_1_grad/tuple/group_deps" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_class" + value { + list { + s: "loc:@gradients/add_1_grad/Reshape" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 10 + } + } + } + } + } +} +node { + name: "gradients/MatMul_1_grad/MatMul_1" + op: "MatMul" + input: "Relu" + input: "gradients/add_1_grad/tuple/control_dependency" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "transpose_a" + value { + b: true + } + } + attr { + key: "transpose_b" + value { + b: false + } + } +} +node { + name: "gradients/MatMul_1_grad/MatMul" + op: "MatMul" + input: "gradients/add_1_grad/tuple/control_dependency" + input: "logits_weights/read" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 100 + } + } + } + } + } + attr { + key: "transpose_a" + value { + b: false + } + } + attr { + key: "transpose_b" + value { + b: true + } + } +} +node { + name: "gradients/MatMul_1_grad/tuple/group_deps" + op: "NoOp" + input: "^gradients/MatMul_1_grad/MatMul" + input: "^gradients/MatMul_1_grad/MatMul_1" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + } + } + } +} +node { + name: "gradients/MatMul_1_grad/tuple/control_dependency_1" + op: "Identity" + input: "gradients/MatMul_1_grad/MatMul_1" + input: "^gradients/MatMul_1_grad/tuple/group_deps" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_class" + value { + list { + s: "loc:@gradients/MatMul_1_grad/MatMul_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + dim { + size: 10 + } + } + } + } + } +} +node { + name: "GradientDescent/update_logits_weights/ApplyGradientDescent" + op: "ApplyGradientDescent" + input: "logits_weights" + input: "GradientDescent/learning_rate" + input: "gradients/MatMul_1_grad/tuple/control_dependency_1" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@logits_weights" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: false + } + } +} +node { + name: "gradients/MatMul_1_grad/tuple/control_dependency" + op: "Identity" + input: "gradients/MatMul_1_grad/MatMul" + input: "^gradients/MatMul_1_grad/tuple/group_deps" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_class" + value { + list { + s: "loc:@gradients/MatMul_1_grad/MatMul" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 100 + } + } + } + } + } +} +node { + name: "gradients/Relu_grad/ReluGrad" + op: "ReluGrad" + input: "gradients/MatMul_1_grad/tuple/control_dependency" + input: "Relu" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 100 + } + } + } + } + } +} +node { + name: "gradients/add_grad/Sum_1" + op: "Sum" + input: "gradients/Relu_grad/ReluGrad" + input: "gradients/add_grad/BroadcastGradientArgs:1" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } +} +node { + name: "gradients/add_grad/Reshape_1" + op: "Reshape" + input: "gradients/add_grad/Sum_1" + input: "gradients/add_grad/Shape_1" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + } + } + } + } +} +node { + name: "gradients/add_grad/Sum" + op: "Sum" + input: "gradients/Relu_grad/ReluGrad" + input: "gradients/add_grad/BroadcastGradientArgs" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 100 + } + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } +} +node { + name: "gradients/add_grad/Reshape" + op: "Reshape" + input: "gradients/add_grad/Sum" + input: "gradients/add_grad/Shape" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 100 + } + } + } + } + } +} +node { + name: "gradients/add_grad/tuple/group_deps" + op: "NoOp" + input: "^gradients/add_grad/Reshape" + input: "^gradients/add_grad/Reshape_1" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + } + } + } +} +node { + name: "gradients/add_grad/tuple/control_dependency_1" + op: "Identity" + input: "gradients/add_grad/Reshape_1" + input: "^gradients/add_grad/tuple/group_deps" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_class" + value { + list { + s: "loc:@gradients/add_grad/Reshape_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + } + } + } + } +} +node { + name: "GradientDescent/update_hidden_biases/ApplyGradientDescent" + op: "ApplyGradientDescent" + input: "hidden_biases" + input: "GradientDescent/learning_rate" + input: "gradients/add_grad/tuple/control_dependency_1" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@hidden_biases" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: false + } + } +} +node { + name: "gradients/add_grad/tuple/control_dependency" + op: "Identity" + input: "gradients/add_grad/Reshape" + input: "^gradients/add_grad/tuple/group_deps" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_class" + value { + list { + s: "loc:@gradients/add_grad/Reshape" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 100 + } + } + } + } + } +} +node { + name: "gradients/MatMul_grad/MatMul_1" + op: "MatMul" + input: "Reshape" + input: "gradients/add_grad/tuple/control_dependency" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 100 + } + } + } + } + } + attr { + key: "transpose_a" + value { + b: true + } + } + attr { + key: "transpose_b" + value { + b: false + } + } +} +node { + name: "gradients/MatMul_grad/MatMul" + op: "MatMul" + input: "gradients/add_grad/tuple/control_dependency" + input: "hidden_weights/read" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 784 + } + } + } + } + } + attr { + key: "transpose_a" + value { + b: false + } + } + attr { + key: "transpose_b" + value { + b: true + } + } +} +node { + name: "gradients/MatMul_grad/tuple/group_deps" + op: "NoOp" + input: "^gradients/MatMul_grad/MatMul" + input: "^gradients/MatMul_grad/MatMul_1" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + } + } + } +} +node { + name: "gradients/MatMul_grad/tuple/control_dependency_1" + op: "Identity" + input: "gradients/MatMul_grad/MatMul_1" + input: "^gradients/MatMul_grad/tuple/group_deps" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_class" + value { + list { + s: "loc:@gradients/MatMul_grad/MatMul_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 100 + } + } + } + } + } +} +node { + name: "GradientDescent/update_hidden_weights/ApplyGradientDescent" + op: "ApplyGradientDescent" + input: "hidden_weights" + input: "GradientDescent/learning_rate" + input: "gradients/MatMul_grad/tuple/control_dependency_1" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@hidden_weights" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 784 + } + dim { + size: 100 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: false + } + } +} +node { + name: "GradientDescent" + op: "NoOp" + input: "^GradientDescent/update_hidden_weights/ApplyGradientDescent" + input: "^GradientDescent/update_hidden_biases/ApplyGradientDescent" + input: "^GradientDescent/update_logits_weights/ApplyGradientDescent" + input: "^GradientDescent/update_logits_biases/ApplyGradientDescent" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_2" + } + } + attr { + key: "_output_shapes" + value { + list { + } + } + } +} +node { + name: "Reshape_3" + op: "Reshape" + input: "SoftmaxCrossEntropyWithLogits" + input: "Slice_2" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + } + } + } + } +} +node { + name: "Mean" + op: "Mean" + input: "Reshape_3" + input: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } +} +node { + name: "_send_Mean_0" + op: "_Send" + input: "Mean" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "client_terminated" + value { + b: true + } + } + attr { + key: "recv_device" + value { + s: "/job:localhost/replica:0/task:0/cpu:0" + } + } + attr { + key: "send_device" + value { + s: "/job:localhost/replica:0/task:0/cpu:0" + } + } + attr { + key: "send_device_incarnation" + value { + i: -5924635994370253548 + } + } + attr { + key: "tensor_name" + value { + s: "Mean:0" + } + } +} +library { +} +versions { + producer: 21 +} diff --git a/tensorflow/tensorboard/components/tf_graph_info_d3v4/demo/index.html b/tensorflow/tensorboard/components/tf_graph_info_d3v4/demo/index.html new file mode 100644 index 0000000000..f7d2ef7ee5 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_graph_info_d3v4/demo/index.html @@ -0,0 +1,94 @@ + + + + + + + +TF Graph Info Demo + + + + diff --git a/tensorflow/tensorboard/components/tf_graph_info_d3v4/tf-graph-icon.html b/tensorflow/tensorboard/components/tf_graph_info_d3v4/tf-graph-icon.html new file mode 100644 index 0000000000..a3e9dc59c5 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_graph_info_d3v4/tf-graph-icon.html @@ -0,0 +1,296 @@ + + + + + + + + + + + diff --git a/tensorflow/tensorboard/components/tf_graph_info_d3v4/tf-graph-info.html b/tensorflow/tensorboard/components/tf_graph_info_d3v4/tf-graph-info.html new file mode 100644 index 0000000000..b33e1e00d0 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_graph_info_d3v4/tf-graph-info.html @@ -0,0 +1,354 @@ + + + + + + + + + + + + diff --git a/tensorflow/tensorboard/components/tf_graph_info_d3v4/tf-node-info.html b/tensorflow/tensorboard/components/tf_graph_info_d3v4/tf-node-info.html new file mode 100644 index 0000000000..0715777370 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_graph_info_d3v4/tf-node-info.html @@ -0,0 +1,651 @@ + + + + + + + + + + + + + + + + + diff --git a/tensorflow/tensorboard/components/tf_graph_info_d3v4/tf-node-list-item.html b/tensorflow/tensorboard/components/tf_graph_info_d3v4/tf-node-list-item.html new file mode 100644 index 0000000000..c15478d126 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_graph_info_d3v4/tf-node-list-item.html @@ -0,0 +1,138 @@ + + + + + + + + + + + + diff --git a/tensorflow/tensorboard/components/tf_graph_loader_d3v4/BUILD b/tensorflow/tensorboard/components/tf_graph_loader_d3v4/BUILD new file mode 100644 index 0000000000..7e01811a57 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_graph_loader_d3v4/BUILD @@ -0,0 +1,25 @@ +package(default_visibility = ["//tensorflow:internal"]) + +load("@io_bazel_rules_closure//closure:defs.bzl", "web_library") +load("//tensorflow/tensorboard:defs.bzl", "tensorboard_ts_library") +load("//tensorflow/tensorboard:defs.bzl", "tensorboard_webcomponent_library") + +licenses(["notice"]) # Apache 2.0 + +web_library( + name = "tf_graph_loader_d3v4", + srcs = [ + "tf-graph-loader.html", + ], + path = "/tf-graph-loader", + deps = [ + "//tensorflow/tensorboard/components/tf_graph_common_d3v4", + "@org_polymer", + ], +) + +filegroup( + name = "all_files", + srcs = glob(["**"]), + tags = ["notsan"], +) diff --git a/tensorflow/tensorboard/components/tf_graph_loader_d3v4/demo/BUILD b/tensorflow/tensorboard/components/tf_graph_loader_d3v4/demo/BUILD new file mode 100644 index 0000000000..b2fc04b2eb --- /dev/null +++ b/tensorflow/tensorboard/components/tf_graph_loader_d3v4/demo/BUILD @@ -0,0 +1,24 @@ +package(default_visibility = ["//tensorflow:internal"]) + +load("@io_bazel_rules_closure//closure:defs.bzl", "web_library") + +licenses(["notice"]) # Apache 2.0 + +# bazel run //third_party/tensorflow/tensorboard/components/tf_graph_loader/demo +web_library( + name = "demo", + srcs = ["index.html"] + glob(["data/**"]), + path = "/tf-graph-loader/demo", + deps = [ + "//tensorflow/tensorboard/components/tf_graph_loader", + "@org_polymer_iron_demo_helpers", + "@org_polymer_paper_styles", + "@org_polymer_webcomponentsjs", + ], +) + +filegroup( + name = "all_files", + srcs = glob(["**"]), + tags = ["notsan"], +) diff --git a/tensorflow/tensorboard/components/tf_graph_loader_d3v4/demo/data/graph.pbtxt b/tensorflow/tensorboard/components/tf_graph_loader_d3v4/demo/data/graph.pbtxt new file mode 100644 index 0000000000..30b2064534 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_graph_loader_d3v4/demo/data/graph.pbtxt @@ -0,0 +1,4606 @@ +node { + name: "GradientDescent/learning_rate" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_3" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.1 + } + } + } +} +node { + name: "gradients/add_grad/Shape_1" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 100 + } + } + } +} +node { + name: "gradients/add_grad/Shape" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\310\000\000\000d\000\000\000" + } + } + } +} +node { + name: "gradients/add_grad/BroadcastGradientArgs" + op: "BroadcastGradientArgs" + input: "gradients/add_grad/Shape" + input: "gradients/add_grad/Shape_1" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + shape { + dim { + size: -1 + } + } + } + } + } +} +node { + name: "gradients/add_1_grad/Shape_1" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 10 + } + } + } +} +node { + name: "gradients/add_1_grad/Shape" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\310\000\000\000\n\000\000\000" + } + } + } +} +node { + name: "gradients/add_1_grad/BroadcastGradientArgs" + op: "BroadcastGradientArgs" + input: "gradients/add_1_grad/Shape" + input: "gradients/add_1_grad/Shape_1" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + shape { + dim { + size: -1 + } + } + } + } + } +} +node { + name: "gradients/Reshape_1_grad/Shape" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\310\000\000\000\n\000\000\000" + } + } + } +} +node { + name: "gradients/SoftmaxCrossEntropyWithLogits_grad/ExpandDims/dim" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: -1 + } + } + } +} +node { + name: "gradients/Reshape_3_grad/Shape" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 200 + } + } + } +} +node { + name: "gradients/Mean_grad/Maximum/y" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } +} +node { + name: "gradients/Mean_grad/Const_1" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } +} +node { + name: "gradients/Mean_grad/Const" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } +} +node { + name: "gradients/Mean_grad/Shape_1" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + } + } + } + } + } +} +node { + name: "gradients/Mean_grad/Prod_1" + op: "Prod" + input: "gradients/Mean_grad/Shape_1" + input: "gradients/Mean_grad/Const_1" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } +} +node { + name: "gradients/Mean_grad/Maximum" + op: "Maximum" + input: "gradients/Mean_grad/Prod_1" + input: "gradients/Mean_grad/Maximum/y" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } +} +node { + name: "gradients/Mean_grad/Shape" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 200 + } + } + } +} +node { + name: "gradients/Mean_grad/Prod" + op: "Prod" + input: "gradients/Mean_grad/Shape" + input: "gradients/Mean_grad/Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } +} +node { + name: "gradients/Mean_grad/floordiv" + op: "FloorDiv" + input: "gradients/Mean_grad/Prod" + input: "gradients/Mean_grad/Maximum" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } +} +node { + name: "gradients/Mean_grad/Cast" + op: "Cast" + input: "gradients/Mean_grad/floordiv" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "DstT" + value { + type: DT_FLOAT + } + } + attr { + key: "SrcT" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } +} +node { + name: "gradients/Mean_grad/Tile/multiples" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 200 + } + } + } +} +node { + name: "gradients/Mean_grad/Reshape/shape" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } +} +node { + name: "gradients/Const" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1 + } + } + } +} +node { + name: "gradients/Shape" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + } + } + } + } + } +} +node { + name: "gradients/Fill" + op: "Fill" + input: "gradients/Shape" + input: "gradients/Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } +} +node { + name: "gradients/Mean_grad/Reshape" + op: "Reshape" + input: "gradients/Fill" + input: "gradients/Mean_grad/Reshape/shape" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } +} +node { + name: "gradients/Mean_grad/Tile" + op: "Tile" + input: "gradients/Mean_grad/Reshape" + input: "gradients/Mean_grad/Tile/multiples" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tmultiples" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + } + } + } + } +} +node { + name: "gradients/Mean_grad/truediv" + op: "RealDiv" + input: "gradients/Mean_grad/Tile" + input: "gradients/Mean_grad/Cast" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + } + } + } + } +} +node { + name: "gradients/Reshape_3_grad/Reshape" + op: "Reshape" + input: "gradients/Mean_grad/truediv" + input: "gradients/Reshape_3_grad/Shape" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + } + } + } + } +} +node { + name: "gradients/SoftmaxCrossEntropyWithLogits_grad/ExpandDims" + op: "ExpandDims" + input: "gradients/Reshape_3_grad/Reshape" + input: "gradients/SoftmaxCrossEntropyWithLogits_grad/ExpandDims/dim" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tdim" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 1 + } + } + } + } + } +} +node { + name: "Const" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } +} +node { + name: "Slice_2/begin" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } +} +node { + name: "Sub_2/y" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } +} +node { + name: "concat_1/axis" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 0 + } + } + } +} +node { + name: "concat_1/values_0" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: -1 + } + } + } +} +node { + name: "Slice_1/size" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } +} +node { + name: "Sub_1/y" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } +} +node { + name: "Shape_2" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\310\000\000\000\n\000\000\000" + } + } + } +} +node { + name: "Rank_2" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 2 + } + } + } +} +node { + name: "Sub_1" + op: "Sub" + input: "Rank_2" + input: "Sub_1/y" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } +} +node { + name: "Slice_1/begin" + op: "Pack" + input: "Sub_1" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "N" + value { + i: 1 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "axis" + value { + i: 0 + } + } +} +node { + name: "Slice_1" + op: "Slice" + input: "Shape_2" + input: "Slice_1/begin" + input: "Slice_1/size" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } +} +node { + name: "concat_1" + op: "ConcatV2" + input: "concat_1/values_0" + input: "Slice_1" + input: "concat_1/axis" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } +} +node { + name: "concat/axis" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 0 + } + } + } +} +node { + name: "concat/values_0" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: -1 + } + } + } +} +node { + name: "Slice/size" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } +} +node { + name: "Sub/y" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } +} +node { + name: "Shape_1" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\310\000\000\000\n\000\000\000" + } + } + } +} +node { + name: "Rank_1" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 2 + } + } + } +} +node { + name: "Sub" + op: "Sub" + input: "Rank_1" + input: "Sub/y" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } +} +node { + name: "Slice/begin" + op: "Pack" + input: "Sub" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "N" + value { + i: 1 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "axis" + value { + i: 0 + } + } +} +node { + name: "Slice" + op: "Slice" + input: "Shape_1" + input: "Slice/begin" + input: "Slice/size" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } +} +node { + name: "concat" + op: "ConcatV2" + input: "concat/values_0" + input: "Slice" + input: "concat/axis" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } +} +node { + name: "Shape" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\310\000\000\000\n\000\000\000" + } + } + } +} +node { + name: "Rank" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 2 + } + } + } +} +node { + name: "Sub_2" + op: "Sub" + input: "Rank" + input: "Sub_2/y" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } +} +node { + name: "Slice_2/size" + op: "Pack" + input: "Sub_2" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "N" + value { + i: 1 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "axis" + value { + i: 0 + } + } +} +node { + name: "Slice_2" + op: "Slice" + input: "Shape" + input: "Slice_2/begin" + input: "Slice_2/size" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } +} +node { + name: "logits_biases" + op: "VariableV2" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "_class" + value { + list { + s: "loc:@logits_biases" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 10 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "logits_biases/read" + op: "Identity" + input: "logits_biases" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@logits_biases" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } +} +node { + name: "logits_weights" + op: "VariableV2" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "_class" + value { + list { + s: "loc:@logits_weights" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 100 + } + dim { + size: 10 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "logits_weights/read" + op: "Identity" + input: "logits_weights" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@logits_weights" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + dim { + size: 10 + } + } + } + } + } +} +node { + name: "hidden_biases" + op: "VariableV2" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "_class" + value { + list { + s: "loc:@hidden_biases" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + } + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 100 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "hidden_biases/read" + op: "Identity" + input: "hidden_biases" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@hidden_biases" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + } + } + } + } +} +node { + name: "hidden_weights" + op: "VariableV2" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "_class" + value { + list { + s: "loc:@hidden_weights" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 784 + } + dim { + size: 100 + } + } + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 784 + } + dim { + size: 100 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "hidden_weights/read" + op: "Identity" + input: "hidden_weights" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@hidden_weights" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 784 + } + dim { + size: 100 + } + } + } + } + } +} +node { + name: "Reshape/shape" + op: "Const" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\310\000\000\000\377\377\377\377" + } + } + } +} +node { + name: "mnist_dataset_train_2/one_hot/depth" + op: "Const" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 10 + } + } + } +} +node { + name: "mnist_dataset_train_2/one_hot/off_value" + op: "Const" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0 + } + } + } +} +node { + name: "mnist_dataset_train_2/one_hot/on_value" + op: "Const" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1 + } + } + } +} +node { + name: "mnist_dataset_train_2/random_shuffle_queue_DequeueMany/n" + op: "Const" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 200 + } + } + } +} +node { + name: "mnist_dataset_train_1/random_shuffle_queue" + op: "RandomShuffleQueueV2" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "capacity" + value { + i: 20000 + } + } + attr { + key: "component_types" + value { + list { + type: DT_FLOAT + type: DT_INT64 + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "min_after_dequeue" + value { + i: 4000 + } + } + attr { + key: "seed" + value { + i: 0 + } + } + attr { + key: "seed2" + value { + i: 0 + } + } + attr { + key: "shapes" + value { + list { + shape { + dim { + size: 28 + } + dim { + size: 28 + } + dim { + size: 1 + } + } + shape { + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "mnist_dataset_train_2/random_shuffle_queue_DequeueMany" + op: "QueueDequeueManyV2" + input: "mnist_dataset_train_1/random_shuffle_queue" + input: "mnist_dataset_train_2/random_shuffle_queue_DequeueMany/n" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "_output_shapes" + value { + list { + shape { + unknown_rank: true + } + shape { + unknown_rank: true + } + } + } + } + attr { + key: "component_types" + value { + list { + type: DT_FLOAT + type: DT_INT64 + } + } + } + attr { + key: "timeout_ms" + value { + i: -1 + } + } +} +node { + name: "Reshape" + op: "Reshape" + input: "mnist_dataset_train_2/random_shuffle_queue_DequeueMany" + input: "Reshape/shape" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: -1 + } + } + } + } + } +} +node { + name: "MatMul" + op: "MatMul" + input: "Reshape" + input: "hidden_weights/read" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 100 + } + } + } + } + } + attr { + key: "transpose_a" + value { + b: false + } + } + attr { + key: "transpose_b" + value { + b: false + } + } +} +node { + name: "add" + op: "Add" + input: "MatMul" + input: "hidden_biases/read" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 100 + } + } + } + } + } +} +node { + name: "Relu" + op: "Relu" + input: "add" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 100 + } + } + } + } + } +} +node { + name: "MatMul_1" + op: "MatMul" + input: "Relu" + input: "logits_weights/read" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "transpose_a" + value { + b: false + } + } + attr { + key: "transpose_b" + value { + b: false + } + } +} +node { + name: "add_1" + op: "Add" + input: "MatMul_1" + input: "logits_biases/read" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 10 + } + } + } + } + } +} +node { + name: "Reshape_1" + op: "Reshape" + input: "add_1" + input: "concat" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 10 + } + } + } + } + } +} +node { + name: "mnist_dataset_train_2/one_hot" + op: "OneHot" + input: "mnist_dataset_train_2/random_shuffle_queue_DequeueMany:1" + input: "mnist_dataset_train_2/one_hot/depth" + input: "mnist_dataset_train_2/one_hot/on_value" + input: "mnist_dataset_train_2/one_hot/off_value" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "TI" + value { + type: DT_INT64 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + unknown_rank: true + } + } + } + } + attr { + key: "axis" + value { + i: -1 + } + } +} +node { + name: "Reshape_2" + op: "Reshape" + input: "mnist_dataset_train_2/one_hot" + input: "concat_1" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 10 + } + } + } + } + } +} +node { + name: "SoftmaxCrossEntropyWithLogits" + op: "SoftmaxCrossEntropyWithLogits" + input: "Reshape_1" + input: "Reshape_2" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + } + shape { + dim { + size: 200 + } + dim { + size: 10 + } + } + } + } + } +} +node { + name: "gradients/SoftmaxCrossEntropyWithLogits_grad/PreventGradient" + op: "PreventGradient" + input: "SoftmaxCrossEntropyWithLogits:1" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "message" + value { + s: "Currently there is no way to take the second derivative of softmax_cross_entropy_with_logits due to the fused implementation\'s interaction with tf.gradients()" + } + } +} +node { + name: "gradients/SoftmaxCrossEntropyWithLogits_grad/mul" + op: "Mul" + input: "gradients/SoftmaxCrossEntropyWithLogits_grad/ExpandDims" + input: "gradients/SoftmaxCrossEntropyWithLogits_grad/PreventGradient" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 10 + } + } + } + } + } +} +node { + name: "gradients/Reshape_1_grad/Reshape" + op: "Reshape" + input: "gradients/SoftmaxCrossEntropyWithLogits_grad/mul" + input: "gradients/Reshape_1_grad/Shape" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 10 + } + } + } + } + } +} +node { + name: "gradients/add_1_grad/Sum_1" + op: "Sum" + input: "gradients/Reshape_1_grad/Reshape" + input: "gradients/add_1_grad/BroadcastGradientArgs:1" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } +} +node { + name: "gradients/add_1_grad/Reshape_1" + op: "Reshape" + input: "gradients/add_1_grad/Sum_1" + input: "gradients/add_1_grad/Shape_1" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } +} +node { + name: "gradients/add_1_grad/Sum" + op: "Sum" + input: "gradients/Reshape_1_grad/Reshape" + input: "gradients/add_1_grad/BroadcastGradientArgs" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } +} +node { + name: "gradients/add_1_grad/Reshape" + op: "Reshape" + input: "gradients/add_1_grad/Sum" + input: "gradients/add_1_grad/Shape" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 10 + } + } + } + } + } +} +node { + name: "gradients/add_1_grad/tuple/group_deps" + op: "NoOp" + input: "^gradients/add_1_grad/Reshape" + input: "^gradients/add_1_grad/Reshape_1" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + } + } + } +} +node { + name: "gradients/add_1_grad/tuple/control_dependency_1" + op: "Identity" + input: "gradients/add_1_grad/Reshape_1" + input: "^gradients/add_1_grad/tuple/group_deps" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_class" + value { + list { + s: "loc:@gradients/add_1_grad/Reshape_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } +} +node { + name: "GradientDescent/update_logits_biases/ApplyGradientDescent" + op: "ApplyGradientDescent" + input: "logits_biases" + input: "GradientDescent/learning_rate" + input: "gradients/add_1_grad/tuple/control_dependency_1" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@logits_biases" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: false + } + } +} +node { + name: "gradients/add_1_grad/tuple/control_dependency" + op: "Identity" + input: "gradients/add_1_grad/Reshape" + input: "^gradients/add_1_grad/tuple/group_deps" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_class" + value { + list { + s: "loc:@gradients/add_1_grad/Reshape" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 10 + } + } + } + } + } +} +node { + name: "gradients/MatMul_1_grad/MatMul_1" + op: "MatMul" + input: "Relu" + input: "gradients/add_1_grad/tuple/control_dependency" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "transpose_a" + value { + b: true + } + } + attr { + key: "transpose_b" + value { + b: false + } + } +} +node { + name: "gradients/MatMul_1_grad/MatMul" + op: "MatMul" + input: "gradients/add_1_grad/tuple/control_dependency" + input: "logits_weights/read" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 100 + } + } + } + } + } + attr { + key: "transpose_a" + value { + b: false + } + } + attr { + key: "transpose_b" + value { + b: true + } + } +} +node { + name: "gradients/MatMul_1_grad/tuple/group_deps" + op: "NoOp" + input: "^gradients/MatMul_1_grad/MatMul" + input: "^gradients/MatMul_1_grad/MatMul_1" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + } + } + } +} +node { + name: "gradients/MatMul_1_grad/tuple/control_dependency_1" + op: "Identity" + input: "gradients/MatMul_1_grad/MatMul_1" + input: "^gradients/MatMul_1_grad/tuple/group_deps" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_class" + value { + list { + s: "loc:@gradients/MatMul_1_grad/MatMul_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + dim { + size: 10 + } + } + } + } + } +} +node { + name: "GradientDescent/update_logits_weights/ApplyGradientDescent" + op: "ApplyGradientDescent" + input: "logits_weights" + input: "GradientDescent/learning_rate" + input: "gradients/MatMul_1_grad/tuple/control_dependency_1" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@logits_weights" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: false + } + } +} +node { + name: "gradients/MatMul_1_grad/tuple/control_dependency" + op: "Identity" + input: "gradients/MatMul_1_grad/MatMul" + input: "^gradients/MatMul_1_grad/tuple/group_deps" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_class" + value { + list { + s: "loc:@gradients/MatMul_1_grad/MatMul" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 100 + } + } + } + } + } +} +node { + name: "gradients/Relu_grad/ReluGrad" + op: "ReluGrad" + input: "gradients/MatMul_1_grad/tuple/control_dependency" + input: "Relu" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 100 + } + } + } + } + } +} +node { + name: "gradients/add_grad/Sum_1" + op: "Sum" + input: "gradients/Relu_grad/ReluGrad" + input: "gradients/add_grad/BroadcastGradientArgs:1" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } +} +node { + name: "gradients/add_grad/Reshape_1" + op: "Reshape" + input: "gradients/add_grad/Sum_1" + input: "gradients/add_grad/Shape_1" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + } + } + } + } +} +node { + name: "gradients/add_grad/Sum" + op: "Sum" + input: "gradients/Relu_grad/ReluGrad" + input: "gradients/add_grad/BroadcastGradientArgs" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 100 + } + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } +} +node { + name: "gradients/add_grad/Reshape" + op: "Reshape" + input: "gradients/add_grad/Sum" + input: "gradients/add_grad/Shape" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 100 + } + } + } + } + } +} +node { + name: "gradients/add_grad/tuple/group_deps" + op: "NoOp" + input: "^gradients/add_grad/Reshape" + input: "^gradients/add_grad/Reshape_1" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + } + } + } +} +node { + name: "gradients/add_grad/tuple/control_dependency_1" + op: "Identity" + input: "gradients/add_grad/Reshape_1" + input: "^gradients/add_grad/tuple/group_deps" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_class" + value { + list { + s: "loc:@gradients/add_grad/Reshape_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + } + } + } + } +} +node { + name: "GradientDescent/update_hidden_biases/ApplyGradientDescent" + op: "ApplyGradientDescent" + input: "hidden_biases" + input: "GradientDescent/learning_rate" + input: "gradients/add_grad/tuple/control_dependency_1" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@hidden_biases" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: false + } + } +} +node { + name: "gradients/add_grad/tuple/control_dependency" + op: "Identity" + input: "gradients/add_grad/Reshape" + input: "^gradients/add_grad/tuple/group_deps" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_class" + value { + list { + s: "loc:@gradients/add_grad/Reshape" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 100 + } + } + } + } + } +} +node { + name: "gradients/MatMul_grad/MatMul_1" + op: "MatMul" + input: "Reshape" + input: "gradients/add_grad/tuple/control_dependency" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 100 + } + } + } + } + } + attr { + key: "transpose_a" + value { + b: true + } + } + attr { + key: "transpose_b" + value { + b: false + } + } +} +node { + name: "gradients/MatMul_grad/MatMul" + op: "MatMul" + input: "gradients/add_grad/tuple/control_dependency" + input: "hidden_weights/read" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + dim { + size: 784 + } + } + } + } + } + attr { + key: "transpose_a" + value { + b: false + } + } + attr { + key: "transpose_b" + value { + b: true + } + } +} +node { + name: "gradients/MatMul_grad/tuple/group_deps" + op: "NoOp" + input: "^gradients/MatMul_grad/MatMul" + input: "^gradients/MatMul_grad/MatMul_1" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + } + } + } +} +node { + name: "gradients/MatMul_grad/tuple/control_dependency_1" + op: "Identity" + input: "gradients/MatMul_grad/MatMul_1" + input: "^gradients/MatMul_grad/tuple/group_deps" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_class" + value { + list { + s: "loc:@gradients/MatMul_grad/MatMul_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 100 + } + } + } + } + } +} +node { + name: "GradientDescent/update_hidden_weights/ApplyGradientDescent" + op: "ApplyGradientDescent" + input: "hidden_weights" + input: "GradientDescent/learning_rate" + input: "gradients/MatMul_grad/tuple/control_dependency_1" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@hidden_weights" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 784 + } + dim { + size: 100 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: false + } + } +} +node { + name: "GradientDescent" + op: "NoOp" + input: "^GradientDescent/update_hidden_weights/ApplyGradientDescent" + input: "^GradientDescent/update_hidden_biases/ApplyGradientDescent" + input: "^GradientDescent/update_logits_weights/ApplyGradientDescent" + input: "^GradientDescent/update_logits_biases/ApplyGradientDescent" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "_XlaCluster" + value { + s: "cluster_2" + } + } + attr { + key: "_output_shapes" + value { + list { + } + } + } +} +node { + name: "Reshape_3" + op: "Reshape" + input: "SoftmaxCrossEntropyWithLogits" + input: "Slice_2" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 200 + } + } + } + } + } +} +node { + name: "Mean" + op: "Mean" + input: "Reshape_3" + input: "Const" + device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_XlaCluster" + value { + s: "cluster_1" + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } +} +node { + name: "_send_Mean_0" + op: "_Send" + input: "Mean" + device: "/job:localhost/replica:0/task:0/cpu:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "client_terminated" + value { + b: true + } + } + attr { + key: "recv_device" + value { + s: "/job:localhost/replica:0/task:0/cpu:0" + } + } + attr { + key: "send_device" + value { + s: "/job:localhost/replica:0/task:0/cpu:0" + } + } + attr { + key: "send_device_incarnation" + value { + i: -5924635994370253548 + } + } + attr { + key: "tensor_name" + value { + s: "Mean:0" + } + } +} +library { +} +versions { + producer: 21 +} diff --git a/tensorflow/tensorboard/components/tf_graph_loader_d3v4/demo/index.html b/tensorflow/tensorboard/components/tf_graph_loader_d3v4/demo/index.html new file mode 100644 index 0000000000..2ffb2a1a59 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_graph_loader_d3v4/demo/index.html @@ -0,0 +1,75 @@ + + + + + +TF Graph Loader Demo + + + diff --git a/tensorflow/tensorboard/components/tf_graph_loader_d3v4/test/index.html b/tensorflow/tensorboard/components/tf_graph_loader_d3v4/test/index.html new file mode 100644 index 0000000000..c8e2027f42 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_graph_loader_d3v4/test/index.html @@ -0,0 +1,30 @@ + + + + + + + + + + + + + + + diff --git a/tensorflow/tensorboard/components/tf_graph_loader_d3v4/test/loader.ts b/tensorflow/tensorboard/components/tf_graph_loader_d3v4/test/loader.ts new file mode 100644 index 0000000000..fcd9f7b529 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_graph_loader_d3v4/test/loader.ts @@ -0,0 +1,25 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the 'License'); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an 'AS IS' BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +suite('graph loader', () => { + let assert = chai.assert; + + test('loader exists', () => { + assert.isTrue(document.getElementById('loader') != null); + }); + + // TODO(bp): write tests. + +}); diff --git a/tensorflow/tensorboard/components/tf_graph_loader_d3v4/tf-graph-loader.html b/tensorflow/tensorboard/components/tf_graph_loader_d3v4/tf-graph-loader.html new file mode 100644 index 0000000000..8d59cbd2aa --- /dev/null +++ b/tensorflow/tensorboard/components/tf_graph_loader_d3v4/tf-graph-loader.html @@ -0,0 +1,184 @@ + + + + + + + + + + -- GitLab From 0a34a7db58276adda9de491a8b83f185fbc30820 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 15 May 2017 16:29:14 -0700 Subject: [PATCH 651/697] - Enable tf.image.decode_image to decode 4 channel PNG images. - Remove lint errors from image_ops_impl.py and decode_image_op_test.py - Switch tf-slim tf-example decoder to use image_ops.decode_image to handle png & jpg dynamically. PiperOrigin-RevId: 156120108 --- .../python/slim/data/tfexample_decoder.py | 48 +++++----------- .../slim/data/tfexample_decoder_test.py | 4 +- .../kernel_tests/decode_image_op_test.py | 29 +++++----- tensorflow/python/ops/image_ops_impl.py | 57 ++++++++++++++----- 4 files changed, 71 insertions(+), 67 deletions(-) diff --git a/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py b/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py index 9e8168dfce..f0e028cd77 100644 --- a/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py +++ b/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py @@ -291,9 +291,8 @@ class Image(ItemHandler): channels: the number of channels in the image. dtype: images will be decoded at this bit depth. Different formats support different bit depths. - See tf.image.decode_png, + See tf.image.decode_image, tf.decode_raw, - tf.image.decode_jpeg: only supports tf.uint8 repeated: if False, decodes a single image. If True, decodes a variable number of image strings from a 1D tensor of strings. """ @@ -326,48 +325,29 @@ class Image(ItemHandler): Args: image_buffer: The tensor representing the encoded image tensor. - image_format: The image format for the image in `image_buffer`. + image_format: The image format for the image in `image_buffer`. If image + format is `raw`, all images are expected to be in this format, otherwise + this op can decode a mix of `jpg` and `png` formats. Returns: A tensor that represents decoded image of self._shape, or (?, ?, self._channels) if self._shape is not specified. """ - - def decode_png(): - return image_ops.decode_png( - image_buffer, self._channels, dtype=self._dtype) + def decode_image(): + """Decodes a png or jpg based on the headers.""" + return image_ops.decode_image(image_buffer, self._channels) def decode_raw(): + """Decodes a raw image.""" return parsing_ops.decode_raw(image_buffer, out_type=self._dtype) - def decode_jpg(): - if self._dtype != dtypes.uint8: - raise ValueError( - 'jpeg decoder can only be used to decode to tf.uint8 but %s was ' - 'requested for a jpeg image.' % self._dtype) - return image_ops.decode_jpeg(image_buffer, self._channels) - - # For RGBA images JPEG is not a valid decoder option. - if self._channels > 3: - pred_fn_pairs = { - math_ops.logical_or( - math_ops.equal(image_format, 'raw'), - math_ops.equal(image_format, 'RAW')): decode_raw, - } - default_decoder = decode_png - else: - pred_fn_pairs = { - math_ops.logical_or( - math_ops.equal(image_format, 'png'), - math_ops.equal(image_format, 'PNG')): decode_png, - math_ops.logical_or( - math_ops.equal(image_format, 'raw'), - math_ops.equal(image_format, 'RAW')): decode_raw, - } - default_decoder = decode_jpg - + pred_fn_pairs = { + math_ops.logical_or( + math_ops.equal(image_format, 'raw'), + math_ops.equal(image_format, 'RAW')): decode_raw, + } image = control_flow_ops.case( - pred_fn_pairs, default=default_decoder, exclusive=True) + pred_fn_pairs, default=decode_image, exclusive=True) image.set_shape([None, None, self._channels]) if self._shape is not None: diff --git a/tensorflow/contrib/slim/python/slim/data/tfexample_decoder_test.py b/tensorflow/contrib/slim/python/slim/data/tfexample_decoder_test.py index 6c83f46e11..506f4bd877 100644 --- a/tensorflow/contrib/slim/python/slim/data/tfexample_decoder_test.py +++ b/tensorflow/contrib/slim/python/slim/data/tfexample_decoder_test.py @@ -228,9 +228,7 @@ class TFExampleDecoderTest(test.TestCase): image_shape = (2, 3, 3) unused_image, serialized_example = self.GenerateImage( image_format='jpeg', image_shape=image_shape) - expected_regex = ('jpeg decoder can only be used to decode to tf.uint8 but ' - '.* was requested for a jpeg image.') - with self.assertRaisesRegexp(ValueError, expected_regex): + with self.assertRaises(TypeError): unused_decoded_image = self.RunDecodeExample( serialized_example, tfexample_decoder.Image(dtype=dtypes.uint16), diff --git a/tensorflow/python/kernel_tests/decode_image_op_test.py b/tensorflow/python/kernel_tests/decode_image_op_test.py index 52f48c3368..b457b5cc86 100644 --- a/tensorflow/python/kernel_tests/decode_image_op_test.py +++ b/tensorflow/python/kernel_tests/decode_image_op_test.py @@ -36,10 +36,10 @@ class DecodeImageOpTest(test.TestCase): def testGif(self): # Read some real GIFs path = os.path.join(prefix_path, "gif", "testdata", "scan.gif") - WIDTH = 20 - HEIGHT = 40 - STRIDE = 5 - shape = (12, HEIGHT, WIDTH, 3) + width = 20 + height = 40 + stride = 5 + shape = (12, height, width, 3) with self.test_session(use_gpu=True) as sess: gif0 = io_ops.read_file(path) @@ -52,13 +52,13 @@ class DecodeImageOpTest(test.TestCase): for frame_idx, frame in enumerate(image0): gt = np.zeros(shape[1:], dtype=np.uint8) - start = frame_idx * STRIDE - end = (frame_idx + 1) * STRIDE - if end <= WIDTH: + start = frame_idx * stride + end = (frame_idx + 1) * stride + if end <= width: gt[:, start:end, :] = 255 else: - start -= WIDTH - end -= WIDTH + start -= width + end -= width gt[start:end, :, :] = 255 self.assertAllClose(frame, gt) @@ -79,11 +79,15 @@ class DecodeImageOpTest(test.TestCase): self.assertEqual(image0.shape, (256, 128, 3)) self.assertAllEqual(image0, image1) + bad_channels = image_ops.decode_image(jpeg0, channels=4) + with self.assertRaises(errors_impl.InvalidArgumentError): + bad_channels.eval() + def testPng(self): # Read some real PNGs, converting to different channel numbers inputs = [(1, "lena_gray.png")] for channels_in, filename in inputs: - for channels in 0, 1, 3: + for channels in 0, 1, 3, 4: with self.test_session(use_gpu=True) as sess: path = os.path.join(prefix_path, "png", "testdata", filename) png0 = io_ops.read_file(path) @@ -100,11 +104,6 @@ class DecodeImageOpTest(test.TestCase): with self.assertRaises(errors_impl.InvalidArgumentError): decode.eval() - def testInvalidChannels(self): - image_bytes = b"unused" - with self.assertRaises(ValueError): - decode = image_ops.decode_image(image_bytes, channels=4) - if __name__ == "__main__": test.main() diff --git a/tensorflow/python/ops/image_ops_impl.py b/tensorflow/python/ops/image_ops_impl.py index cf7be5759e..3e140ce047 100644 --- a/tensorflow/python/ops/image_ops_impl.py +++ b/tensorflow/python/ops/image_ops_impl.py @@ -32,9 +32,9 @@ from tensorflow.python.ops import clip_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gen_image_ops from tensorflow.python.ops import gen_nn_ops -from tensorflow.python.ops import string_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops +from tensorflow.python.ops import string_ops from tensorflow.python.ops import variables @@ -619,7 +619,7 @@ def resize_image_with_crop_or_pad(image, target_height, target_width): # Make sure our checks come first, so that error messages are clearer. if _is_tensor(target_height): target_height = control_flow_ops.with_dependencies( - assert_ops, target_height) + assert_ops, target_height) if _is_tensor(target_width): target_width = control_flow_ops.with_dependencies(assert_ops, target_width) @@ -698,9 +698,12 @@ def resize_images(images, `method` can be one of: - * `ResizeMethod.BILINEAR`: [Bilinear interpolation.](https://en.wikipedia.org/wiki/Bilinear_interpolation) - * `ResizeMethod.NEAREST_NEIGHBOR`: [Nearest neighbor interpolation.](https://en.wikipedia.org/wiki/Nearest-neighbor_interpolation) - * `ResizeMethod.BICUBIC`: [Bicubic interpolation.](https://en.wikipedia.org/wiki/Bicubic_interpolation) + * `ResizeMethod.BILINEAR`: [Bilinear interpolation.]( + https://en.wikipedia.org/wiki/Bilinear_interpolation) + * `ResizeMethod.NEAREST_NEIGHBOR`: [Nearest neighbor interpolation.]( + https://en.wikipedia.org/wiki/Nearest-neighbor_interpolation) + * `ResizeMethod.BICUBIC`: [Bicubic interpolation.]( + https://en.wikipedia.org/wiki/Bicubic_interpolation) * `ResizeMethod.AREA`: Area interpolation. Args: @@ -961,6 +964,7 @@ def adjust_contrast(images, contrast_factor): def adjust_gamma(image, gamma=1, gain=1): """Performs Gamma Correction on the input image. + Also known as Power Law Transform. This function transforms the input image pixelwise according to the equation Out = In**gamma after scaling each pixel to the range 0 to 1. @@ -973,6 +977,9 @@ def adjust_gamma(image, gamma=1, gain=1): Returns: A Tensor. Gamma corrected output image. + Raises: + ValueError: If gamma is negative. + Notes: For gamma greater than 1, the histogram will shift towards left and the output image will be darker than the input image. @@ -983,16 +990,17 @@ def adjust_gamma(image, gamma=1, gain=1): [1] http://en.wikipedia.org/wiki/Gamma_correction """ - with ops.op_scope([image, gamma, gain], None, 'adjust_gamma') as name: + with ops.op_scope([image, gamma, gain], None, 'adjust_gamma'): # Convert pixel value to DT_FLOAT for computing adjusted image img = ops.convert_to_tensor(image, name='img', dtype=dtypes.float32) # Keep image dtype for computing the scale of corresponding dtype image = ops.convert_to_tensor(image, name='image') if gamma < 0: - raise ValueError("Gamma should be a non-negative real number") + raise ValueError('Gamma should be a non-negative real number') # scale = max(dtype) - min(dtype) - scale = constant_op.constant(image.dtype.limits[1] - image.dtype.limits[0], dtype=dtypes.float32) + scale = constant_op.constant(image.dtype.limits[1] - image.dtype.limits[0], + dtype=dtypes.float32) # According to the definition of gamma correction adjusted_img = (img / scale) ** gamma * scale * gain @@ -1305,6 +1313,7 @@ def adjust_saturation(image, saturation_factor, name=None): def decode_image(contents, channels=None, name=None): """Convenience function for `decode_gif`, `decode_jpeg`, and `decode_png`. + Detects whether an image is a GIF, JPEG, or PNG, and performs the appropriate operation to convert the input bytes `string` into a `Tensor` of type `uint8`. @@ -1324,35 +1333,53 @@ def decode_image(contents, channels=None, name=None): `Tensor` with type `uint8` with shape `[height, width, num_channels]` for JPEG and PNG images and shape `[num_frames, height, width, 3]` for GIF images. + + Raises: + ValueError: On incorrect number of channels. """ - with ops.name_scope(name, 'decode_image') as scope: - if channels not in (None, 0, 1, 3): - raise ValueError('channels must be in (None, 0, 1, 3)') + with ops.name_scope(name, 'decode_image'): + if channels not in (None, 0, 1, 3, 4): + raise ValueError('channels must be in (None, 0, 1, 3, 4)') substr = string_ops.substr(contents, 0, 3) def _gif(): + """Decodes a GIF image.""" # Create assert op to check that bytes are GIF decodable is_gif = math_ops.equal(substr, b'\x47\x49\x46', name='is_gif') decode_msg = 'Unable to decode bytes as JPEG, PNG, or GIF' assert_decode = control_flow_ops.Assert(is_gif, [decode_msg]) # Create assert to make sure that channels is not set to 1 # Already checked above that channels is in (None, 0, 1, 3) + gif_channels = 0 if channels is None else channels - good_channels = math_ops.not_equal(gif_channels, 1, name='check_channels') + good_channels = math_ops.logical_and( + math_ops.not_equal(gif_channels, 1, name='check_gif_channels'), + math_ops.not_equal(gif_channels, 4, name='check_gif_channels') + ) channels_msg = 'Channels must be in (None, 0, 3) when decoding GIF images' assert_channels = control_flow_ops.Assert(good_channels, [channels_msg]) with ops.control_dependencies([assert_decode, assert_channels]): return gen_image_ops.decode_gif(contents) def _png(): + """Decodes a PNG image.""" return gen_image_ops.decode_png(contents, channels) def check_png(): + """Checks if an image is PNG.""" is_png = math_ops.equal(substr, b'\211PN', name='is_png') return control_flow_ops.cond(is_png, _png, _gif, name='cond_png') def _jpeg(): - return gen_image_ops.decode_jpeg(contents, channels) + """Decodes a jpeg image.""" + jpeg_channels = 0 if channels is None else channels + good_channels = math_ops.not_equal(jpeg_channels, 4, + name='check_jpeg_channels') + channels_msg = ('Channels must be in (None, 0, 1, 3) when decoding JPEG ' + 'images') + assert_channels = control_flow_ops.Assert(good_channels, [channels_msg]) + with ops.control_dependencies([assert_channels]): + return gen_image_ops.decode_jpeg(contents, channels) # Decode normal JPEG images (start with \xff\xd8\xff\xe0) # as well as JPEG images with EXIF data (start with \xff\xd8\xff\xe1). @@ -1424,7 +1451,7 @@ def total_variation(images, name=None): # Calculate the total variation by taking the absolute value of the # pixel-differences and summing over the appropriate axis. - tot_var = math_ops.reduce_sum(math_ops.abs(pixel_dif1), axis=sum_axis) + \ - math_ops.reduce_sum(math_ops.abs(pixel_dif2), axis=sum_axis) + tot_var = (math_ops.reduce_sum(math_ops.abs(pixel_dif1), axis=sum_axis) + + math_ops.reduce_sum(math_ops.abs(pixel_dif2), axis=sum_axis)) return tot_var -- GitLab From 7c1339e32d7c8c3b95fbb11799bcef4795a9b72a Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 15 May 2017 16:35:55 -0700 Subject: [PATCH 652/697] Fix some unused variable warnings when IS_MOBILE_PLATFORM is defined. PiperOrigin-RevId: 156120864 --- tensorflow/core/common_runtime/simple_graph_execution_state.cc | 3 ++- tensorflow/core/kernels/lrn_op.cc | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/tensorflow/core/common_runtime/simple_graph_execution_state.cc b/tensorflow/core/common_runtime/simple_graph_execution_state.cc index 47cf29d403..3806f9f47f 100644 --- a/tensorflow/core/common_runtime/simple_graph_execution_state.cc +++ b/tensorflow/core/common_runtime/simple_graph_execution_state.cc @@ -236,11 +236,12 @@ Status SimpleGraphExecutionState::InitBaseGraph( const BuildGraphOptions& options) { const GraphDef* graph_def = &original_graph_def_; +#ifndef IS_MOBILE_PLATFORM GraphDef optimized_graph; + const RewriterConfig& rewrite_options = session_options_->config.graph_options().rewrite_options(); -#ifndef IS_MOBILE_PLATFORM if (grappler::MetaOptimizerEnabled(rewrite_options)) { // Adding this functionalty in steps. The first step is to make sure // we don't break dependencies. The second step will be to turn the diff --git a/tensorflow/core/kernels/lrn_op.cc b/tensorflow/core/kernels/lrn_op.cc index 3435486c95..c905ebc84a 100644 --- a/tensorflow/core/kernels/lrn_op.cc +++ b/tensorflow/core/kernels/lrn_op.cc @@ -79,11 +79,11 @@ struct LaunchLRN { const int rows = static_cast(in.dim_size(1)); const int cols = static_cast(in.dim_size(2)); const int depth = static_cast(in.dim_size(3)); - const int nodes = cols * rows; #if defined(IS_MOBILE_PLATFORM) SingleThreadedLRN(in, batch, rows, cols, depth, output); #else + const int nodes = cols * rows; if (depth > kSingleThreadedLRNDepthCutoff && (beta_ == T(0.5) || beta_ == T(1))) { SingleThreadedLRN(in, batch, rows, cols, depth, output); -- GitLab From a8b19784a7a3f2b21ea604571cb09b11c1d272e8 Mon Sep 17 00:00:00 2001 From: Dandelion Man? Date: Mon, 15 May 2017 16:47:43 -0700 Subject: [PATCH 653/697] Fix graph rendering in d3v4. Note, there are still some bugs with layout/rendering (e.g. try to open/close a node group and things get weird.) PiperOrigin-RevId: 156122210 --- .../tf_graph_common_d3v4/annotation.ts | 18 ++++++--------- .../components/tf_graph_common_d3v4/scene.ts | 22 +++++++++---------- 2 files changed, 17 insertions(+), 23 deletions(-) diff --git a/tensorflow/tensorboard/components/tf_graph_common_d3v4/annotation.ts b/tensorflow/tensorboard/components/tf_graph_common_d3v4/annotation.ts index 6db0cd5519..bde3829778 100644 --- a/tensorflow/tensorboard/components/tf_graph_common_d3v4/annotation.ts +++ b/tensorflow/tensorboard/components/tf_graph_common_d3v4/annotation.ts @@ -75,9 +75,7 @@ module tf.graph.scene.annotation { addAnnotationLabel( aGroup, a.node.name, a, Class.Annotation.ELLIPSIS); } - }); - - annotationGroups + }).merge(annotationGroups) .attr( 'class', a => { @@ -202,20 +200,18 @@ function update(aGroup, d: render.RenderNodeInfo, a: render.Annotation, } // label position - aGroup.select('text.' + Class.Annotation.LABEL).transition().attr({ - x: cx + a.dx + (a.isIn ? -1 : 1) * (a.width / 2 + a.labelOffset), - y: d.y + a.dy - }); + aGroup.select('text.' + Class.Annotation.LABEL).transition() + .attr('x', cx + a.dx + (a.isIn ? -1 : 1) * (a.width / 2 + a.labelOffset)) + .attr('y', d.y + a.dy); // Some annotations (such as summary) are represented using a 12x12 image tag. // Purposely omitted units (e.g. pixels) since the images are vector graphics. // If there is an image, we adjust the location of the image to be vertically // centered with the node and horizontally centered between the arrow and the // text label. - aGroup.select('use.summary').transition().attr({ - x: cx + a.dx - 3, - y: d.y + a.dy - 6 - }); + aGroup.select('use.summary').transition() + .attr('x', cx + a.dx - 3) + .attr('y', d.y + a.dy - 6); // Node position (only one of the shape selection will be non-empty.) positionEllipse( diff --git a/tensorflow/tensorboard/components/tf_graph_common_d3v4/scene.ts b/tensorflow/tensorboard/components/tf_graph_common_d3v4/scene.ts index 06f03e910a..023bc161f5 100644 --- a/tensorflow/tensorboard/components/tf_graph_common_d3v4/scene.ts +++ b/tensorflow/tensorboard/components/tf_graph_common_d3v4/scene.ts @@ -453,12 +453,11 @@ export function translate(selection, x0: number, y0: number) { */ export function positionRect(rect, cx: number, cy: number, width: number, height: number) { - rect.transition().attr({ - x: cx - width / 2, - y: cy - height / 2, - width: width, - height: height - }); + rect.transition() + .attr('x', cx - width / 2) + .attr('y', cy - height / 2) + .attr('width', width) + .attr('height', height); }; /** @@ -499,12 +498,11 @@ export function positionButton(button, renderNode: render.RenderNodeInfo) { */ export function positionEllipse(ellipse, cx: number, cy: number, width: number, height: number) { - ellipse.transition().attr({ - cx: cx, - cy: cy, - rx: width / 2, - ry: height / 2 - }); + ellipse.transition() + .attr('cx', cx) + .attr('cy', cy) + .attr('rx', width / 2) + .attr('ry', height / 2); }; /** -- GitLab From 60ee61b084f764da0fa67ed74d8458ca2f2150f0 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 15 May 2017 16:56:56 -0700 Subject: [PATCH 654/697] Automated g4 rollback of changelist 156112836 PiperOrigin-RevId: 156123287 --- tensorflow/BUILD | 1 - tensorflow/core/BUILD | 1 - tensorflow/core/kernels/neon/BUILD | 43 -- .../core/kernels/neon/depthwiseconv_float.h | 725 ------------------ .../kernels/neon/neon_depthwise_conv_op.cc | 200 ----- tensorflow/core/kernels/neon/types.h | 71 -- tensorflow/python/kernel_tests/BUILD | 15 - .../neon_depthwise_conv_op_test.py | 287 ------- 8 files changed, 1343 deletions(-) delete mode 100644 tensorflow/core/kernels/neon/BUILD delete mode 100644 tensorflow/core/kernels/neon/depthwiseconv_float.h delete mode 100644 tensorflow/core/kernels/neon/neon_depthwise_conv_op.cc delete mode 100644 tensorflow/core/kernels/neon/types.h delete mode 100644 tensorflow/python/kernel_tests/neon_depthwise_conv_op_test.py diff --git a/tensorflow/BUILD b/tensorflow/BUILD index a3c0bd9a5b..503ad79a38 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -285,7 +285,6 @@ filegroup( "//tensorflow/core/grappler/utils:all_files", "//tensorflow/core/kernels:all_files", "//tensorflow/core/kernels/hexagon:all_files", - "//tensorflow/core/kernels/neon:all_files", "//tensorflow/core/ops/compat:all_files", "//tensorflow/core/platform/cloud:all_files", "//tensorflow/core/platform/default/build_config:all_files", diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 28ef29d563..a0d56df4aa 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -737,7 +737,6 @@ cc_library( "//tensorflow/core/kernels:array_not_windows", "//tensorflow/core/kernels:math_not_windows", "//tensorflow/core/kernels:quantized_ops", - "//tensorflow/core/kernels/neon:neon_depthwise_conv_op", ]) + if_mkl([ "//tensorflow/core/kernels:mkl_concat_op", "//tensorflow/core/kernels:mkl_conv_op", diff --git a/tensorflow/core/kernels/neon/BUILD b/tensorflow/core/kernels/neon/BUILD deleted file mode 100644 index 7641516e3b..0000000000 --- a/tensorflow/core/kernels/neon/BUILD +++ /dev/null @@ -1,43 +0,0 @@ -# Description: -# Kernel implementations using Neon intrinsics. -# -package( - default_visibility = ["//visibility:public"], - features = ["-parse_headers"], -) - -licenses(["notice"]) # Apache 2.0 - -load( - "//tensorflow:tensorflow.bzl", - "tf_kernel_library", -) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) - -tf_kernel_library( - name = "neon_depthwise_conv_op", - hdrs = [ - "depthwiseconv_float.h", - "types.h", - ], - prefix = "neon_depthwise_conv_op", - deps = [ - "//tensorflow/core:core_cpu", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:nn_ops_op_lib", - "//tensorflow/core/kernels:ops_util", - "@gemmlowp//:gemmlowp", - ], -) diff --git a/tensorflow/core/kernels/neon/depthwiseconv_float.h b/tensorflow/core/kernels/neon/depthwiseconv_float.h deleted file mode 100644 index acd58a644f..0000000000 --- a/tensorflow/core/kernels/neon/depthwiseconv_float.h +++ /dev/null @@ -1,725 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_NEON_DEPTHWISECONV_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_NEON_DEPTHWISECONV_H_ - -#include "public/gemmlowp.h" -#include "tensorflow/core/kernels/neon/types.h" - -#if defined(__ARM_NEON__) || defined(__ARM_NEON) -#define USE_NEON -#include -#endif - -namespace tensorflow { -namespace neon { - -// Implementation of float DepthwiseConv - -template -struct FloatDepthwiseConvKernel {}; - -#ifdef USE_NEON - -template <> -struct FloatDepthwiseConvKernel { - static void Run(int num_output_pixels, int input_depth, int depth_multiplier, - const float* input_ptr, int input_ptr_increment, - const float* filter_ptr, float* acc_buffer_ptr) { - // Load the filters - float32x4_t filter[2]; - for (int i = 0; i < 2; i++) { - filter[i] = vld1q_f32(filter_ptr + 4 * i); - } - int outp = 0; - // Handle 2 output pixels at a time. - for (; outp <= num_output_pixels - 2; outp += 2) { - // Load the inputs - float32x4_t input[4]; - for (int i = 0; i < 4; i++) { - input[i] = vld1q_f32(input_ptr + 4 * i); - } - input_ptr += 16; - // Load the accumulators from acc_buffer - float32x4_t acc[4]; - for (int i = 0; i < 4; i++) { - acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i); - } - // Multiply-accumulate - acc[0] = vmlaq_f32(acc[0], input[0], filter[0]); - acc[1] = vmlaq_f32(acc[1], input[1], filter[1]); - acc[2] = vmlaq_f32(acc[2], input[2], filter[0]); - acc[3] = vmlaq_f32(acc[3], input[3], filter[1]); - // Store the accumulators back to acc_buffer - for (int i = 0; i < 4; i++) { - vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]); - } - acc_buffer_ptr += 16; - } - // Handle one output pixel at a time. - for (; outp < num_output_pixels; outp++) { - // Load the inputs - float32x4_t input[2]; - for (int i = 0; i < 2; i++) { - input[i] = vld1q_f32(input_ptr + 4 * i); - } - input_ptr += 8; - // Load the accumulators from acc_buffer - float32x4_t acc[2]; - for (int i = 0; i < 2; i++) { - acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i); - } - // Multiply-accumulate - for (int i = 0; i < 2; i++) { - acc[i] = vmlaq_f32(acc[i], input[i], filter[i]); - } - // Store the accumulators back to acc_buffer - for (int i = 0; i < 2; i++) { - vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]); - } - acc_buffer_ptr += 8; - } - } -}; - -template <> -struct FloatDepthwiseConvKernel { - static void Run(int num_output_pixels, int input_depth, int depth_multiplier, - const float* input_ptr, int input_ptr_increment, - const float* filter_ptr, float* acc_buffer_ptr) { - const float32x2_t filters = vld1_f32(filter_ptr); - const float32x4_t filters_dup2 = vcombine_f32(filters, filters); - int outp = 0; - // Handle 8 output pixels at a time. - for (; outp <= num_output_pixels - 8; outp += 8) { - // Load the inputs - float32x4_t input[4]; - for (int i = 0; i < 4; i++) { - input[i] = vld1q_f32(input_ptr + 4 * i); - } - input_ptr += 16; - // Load the accumulators from acc_buffer - float32x4_t acc[4]; - for (int i = 0; i < 4; i++) { - acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i); - } - // Multiply-accumulate - for (int i = 0; i < 4; i++) { - acc[i] = vmlaq_f32(acc[i], input[i], filters_dup2); - } - // Store the accumulators back to acc_buffer - for (int i = 0; i < 4; i++) { - vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]); - } - acc_buffer_ptr += 16; - } - // Handle 4 output pixels at a time. - for (; outp <= num_output_pixels - 4; outp += 4) { - // Load the inputs - float32x4_t input[2]; - for (int i = 0; i < 2; i++) { - input[i] = vld1q_f32(input_ptr + 4 * i); - } - input_ptr += 8; - // Load the accumulators from acc_buffer - float32x4_t acc[2]; - for (int i = 0; i < 2; i++) { - acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i); - } - // Multiply-accumulate - for (int i = 0; i < 2; i++) { - acc[i] = vmlaq_f32(acc[i], input[i], filters_dup2); - } - // Store the accumulators back to acc_buffer - for (int i = 0; i < 2; i++) { - vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]); - } - acc_buffer_ptr += 8; - } - // Handle 2 output pixels at a time. - for (; outp <= num_output_pixels - 2; outp += 2) { - // Load the inputs - const float32x4_t input = vld1q_f32(input_ptr); - input_ptr += 4; - // Load the accumulators from acc_buffer - float32x4_t acc = vld1q_f32(acc_buffer_ptr); - // Multiply-accumulate - acc = vmlaq_f32(acc, input, filters_dup2); - // Store the accumulators back to acc_buffer - vst1q_f32(acc_buffer_ptr, acc); - acc_buffer_ptr += 4; - } - // Handle 1 output pixel at a time - for (; outp < num_output_pixels; outp++) { - // Load the inputs - const float32x2_t input = vld1_f32(input_ptr); - input_ptr += 2; - // Load the accumulators from acc_buffer - float32x2_t acc = vld1_f32(acc_buffer_ptr); - // Multiply-accumulate - acc = vmla_f32(acc, input, filters); - // Store the accumulators back to acc_buffer - vst1_f32(acc_buffer_ptr, acc); - acc_buffer_ptr += 2; - } - } -}; - -template <> -struct FloatDepthwiseConvKernel { - static void Run(int num_output_pixels, int input_depth, int depth_multiplier, - const float* input_ptr, int input_ptr_increment, - const float* filter_ptr, float* acc_buffer_ptr) { - // Handle one output pixel at a time. - for (int outp = 0; outp < num_output_pixels; outp++) { - const float* local_filter_ptr = filter_ptr; - const float* local_input_ptr = input_ptr; - int ic = 0; - // Handle 16 input channels at a time. - for (; ic <= input_depth - 16; ic += 16) { - // Load the filters - float32x4_t filter[4]; - for (int i = 0; i < 4; i++) { - filter[i] = vld1q_f32(local_filter_ptr + 4 * i); - } - local_filter_ptr += 16; - // Load the inputs - float32x4_t input[4]; - for (int i = 0; i < 4; i++) { - input[i] = vld1q_f32(local_input_ptr + 4 * i); - } - local_input_ptr += 16; - // Load the accumulators from acc_buffer - float32x4_t acc[4]; - for (int i = 0; i < 4; i++) { - acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i); - } - // Multiply-accumulate - for (int i = 0; i < 4; i++) { - acc[i] = vmlaq_f32(acc[i], input[i], filter[i]); - } - // Store the accumulators back to acc_buffer - for (int i = 0; i < 4; i++) { - vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]); - } - acc_buffer_ptr += 16; - } - // Handle 4 input channels at a time. - for (; ic <= input_depth - 4; ic += 4) { - // Load the filters - float32x4_t filter; - filter = vld1q_f32(local_filter_ptr); - local_filter_ptr += 4; - // Load the inputs - float32x4_t input; - input = vld1q_f32(local_input_ptr); - local_input_ptr += 4; - // Load the accumulators from acc_buffer - float32x4_t acc; - acc = vld1q_f32(acc_buffer_ptr); - // Multiply-accumulate - acc = vmlaq_f32(acc, input, filter); - // Store the accumulators back to acc_buffer - vst1q_f32(acc_buffer_ptr, acc); - acc_buffer_ptr += 4; - } - // Handle one input channel at a time. - for (; ic < input_depth; ic++) { - const float input_val = *local_input_ptr++; - const float filter_val = *local_filter_ptr++; - *acc_buffer_ptr++ += filter_val * input_val; - } - input_ptr += input_ptr_increment; - } - } -}; - -template <> -struct FloatDepthwiseConvKernel { - static void Run(int num_output_pixels, int input_depth, int depth_multiplier, - const float* input_ptr, int input_ptr_increment, - const float* filter_ptr, float* acc_buffer_ptr) { - // Handle one output pixel at a time. - for (int outp = 0; outp < num_output_pixels; outp++) { - const float* local_filter_ptr = filter_ptr; - const float* local_input_ptr = input_ptr; - int ic = 0; - // Handle 2 input channels at a time. - for (; ic <= input_depth - 2; ic += 2) { - // Load the filters - float32x4_t filter[4]; - for (int i = 0; i < 4; i++) { - filter[i] = vld1q_f32(local_filter_ptr + 4 * i); - } - local_filter_ptr += 16; - // Load the inputs - const float32x2_t input = vld1_f32(local_input_ptr); - local_input_ptr += 2; - // Load the accumulators from acc_buffer - float32x4_t acc[4]; - for (int i = 0; i < 4; i++) { - acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i); - } - // Multiply-accumulate - acc[0] = vmlaq_lane_f32(acc[0], filter[0], input, 0); - acc[1] = vmlaq_lane_f32(acc[1], filter[1], input, 0); - acc[2] = vmlaq_lane_f32(acc[2], filter[2], input, 1); - acc[3] = vmlaq_lane_f32(acc[3], filter[3], input, 1); - // Store the accumulators back to acc_buffer - for (int i = 0; i < 4; i++) { - vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]); - } - acc_buffer_ptr += 16; - } - // Handle one input channel at a time. - for (; ic < input_depth; ic++) { - // Load the filters - float32x4_t filter[2]; - for (int i = 0; i < 2; i++) { - filter[i] = vld1q_f32(local_filter_ptr + 4 * i); - } - local_filter_ptr += 8; - // Load the inputs - const float input_val = *local_input_ptr++; - // Load the accumulators from acc_buffer - float32x4_t acc[2]; - for (int i = 0; i < 2; i++) { - acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i); - } - // Multiply-accumulate - for (int i = 0; i < 2; i++) { - acc[i] = vmlaq_n_f32(acc[i], filter[i], input_val); - } - // Store the accumulators back to acc_buffer - for (int i = 0; i < 2; i++) { - vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]); - } - acc_buffer_ptr += 8; - } - input_ptr += input_ptr_increment; - } - } -}; - -template <> -struct FloatDepthwiseConvKernel { - static void Run(int num_output_pixels, int input_depth, int depth_multiplier, - const float* input_ptr, int input_ptr_increment, - const float* filter_ptr, float* acc_buffer_ptr) { - // Handle one output pixel at a time. - for (int outp = 0; outp < num_output_pixels; outp++) { - const float* local_filter_ptr = filter_ptr; - const float* local_input_ptr = input_ptr; - int ic = 0; - // Handle 8 input channels at a time. - for (; ic <= input_depth - 8; ic += 8) { - // Load the filters - float32x4_t filter[4]; - for (int i = 0; i < 4; i++) { - filter[i] = vld1q_f32(local_filter_ptr + 4 * i); - } - local_filter_ptr += 16; - // Load the inputs - float32x4x2_t input_dup2[2]; - for (int i = 0; i < 2; i++) { - const float32x4_t input = vld1q_f32(local_input_ptr + 4 * i); - input_dup2[i] = vzipq_f32(input, input); - } - local_input_ptr += 8; - // Load the accumulators from acc_buffer - float32x4_t acc[4]; - for (int i = 0; i < 4; i++) { - acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i); - } - // Multiply-accumulate - acc[0] = vmlaq_f32(acc[0], filter[0], input_dup2[0].val[0]); - acc[1] = vmlaq_f32(acc[1], filter[1], input_dup2[0].val[1]); - acc[2] = vmlaq_f32(acc[2], filter[2], input_dup2[1].val[0]); - acc[3] = vmlaq_f32(acc[3], filter[3], input_dup2[1].val[1]); - // Store the accumulators back to acc_buffer - for (int i = 0; i < 4; i++) { - vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]); - } - acc_buffer_ptr += 16; - } - // Handle 4 input channels at a time. - for (; ic <= input_depth - 4; ic += 4) { - // Load the filters - float32x2_t filter[4]; - for (int i = 0; i < 4; i++) { - filter[i] = vld1_f32(local_filter_ptr + 2 * i); - } - local_filter_ptr += 8; - // Load the inputs - const float32x4_t input = vld1q_f32(local_input_ptr); - local_input_ptr += 4; - // Load the accumulators from acc_buffer - float32x2_t acc[4]; - for (int i = 0; i < 4; i++) { - acc[i] = vld1_f32(acc_buffer_ptr + 2 * i); - } - // Multiply-accumulate - acc[0] = vmla_lane_f32(acc[0], filter[0], vget_low_f32(input), 0); - acc[1] = vmla_lane_f32(acc[1], filter[1], vget_low_f32(input), 1); - acc[2] = vmla_lane_f32(acc[2], filter[2], vget_high_f32(input), 0); - acc[3] = vmla_lane_f32(acc[3], filter[3], vget_high_f32(input), 1); - // Store the accumulators back to acc_buffer - for (int i = 0; i < 4; i++) { - vst1_f32(acc_buffer_ptr + 2 * i, acc[i]); - } - acc_buffer_ptr += 8; - } - // Handle 2 input channels at a time. - for (; ic <= input_depth - 2; ic += 2) { - // Load the filters - const float32x4_t filter = vld1q_f32(local_filter_ptr); - local_filter_ptr += 4; - // Load the inputs - const float32x2_t input = vld1_f32(local_input_ptr); - local_input_ptr += 2; - // Load the accumulators from acc_buffer - float32x2_t acc[2]; - for (int i = 0; i < 2; i++) { - acc[i] = vld1_f32(acc_buffer_ptr + 2 * i); - } - // Multiply-accumulate - acc[0] = vmla_lane_f32(acc[0], vget_low_f32(filter), input, 0); - acc[1] = vmla_lane_f32(acc[1], vget_high_f32(filter), input, 1); - // Store the accumulators back to acc_buffer - for (int i = 0; i < 2; i++) { - vst1_f32(acc_buffer_ptr + 2 * i, acc[i]); - } - acc_buffer_ptr += 4; - } - // Handle one input channel at a time. - for (; ic < input_depth; ic++) { - // Load the inputs - const float input_val = *local_input_ptr++; - // Multiply-accumulate - for (int i = 0; i < 2; i++) { - acc_buffer_ptr[i] += local_filter_ptr[i] * input_val; - } - local_filter_ptr += 2; - acc_buffer_ptr += 2; - } - input_ptr += input_ptr_increment; - } - } -}; -#endif - -// Accumulates the effect of one row of the filter, on a segment of one row -// of the output, accessing the corresponding one row of the input. -template -void FloatDepthwiseConvAccumRow(int stride, int input_depth, int input_width, - const float* input_data, int pad_width, - int depth_multiplier, int filter_width, - const float* filter_data, - int out_x_buffer_start, int out_x_buffer_end, - int output_depth, float* acc_buffer) { -#ifdef GEMMLOWP_PROFILING - gemmlowp::ScopedProfilingLabel label(__PRETTY_FUNCTION__); -#endif - // Sanity check parameters. This is important in particular to ensure - // that we keep the number of template instantiations minimal, so we don't - // increase binary size unnecessarily. - static_assert(kFixedDepthMultiplier || !kFixedInputDepth, ""); - static_assert(kFixedInputDepth || kAllowStrided, ""); - DCHECK(stride == 1 || kAllowStrided); - if (kFixedInputDepth) { - DCHECK_EQ(input_depth, kFixedInputDepth); - } - if (kFixedDepthMultiplier) { - DCHECK_EQ(depth_multiplier, kFixedDepthMultiplier); - } - DCHECK_EQ(output_depth, input_depth * depth_multiplier); - const int input_ptr_increment = stride * input_depth; - const float* filter_base_ptr = filter_data; - for (int filter_x = 0; filter_x < filter_width; ++filter_x) { - // For the current (filter_x, filter_y) point in the filter, - // compute the boundaries of the corresponding output row segment. - int out_x_loop_start_unclampled = 0; - int out_x_loop_end_unclampled = 0; - if (kAllowStrided) { - if (stride == 2) { - out_x_loop_start_unclampled = (pad_width - filter_x + 1) / 2; - out_x_loop_end_unclampled = - (pad_width + input_width - filter_x + 1) / 2; - } else if (stride == 4) { - out_x_loop_start_unclampled = (pad_width - filter_x + 3) / 4; - out_x_loop_end_unclampled = - (pad_width + input_width - filter_x + 3) / 4; - } else { - out_x_loop_start_unclampled = - (pad_width - filter_x + stride - 1) / stride; - out_x_loop_end_unclampled = - (pad_width + input_width - filter_x + stride - 1) / stride; - } - } else { - out_x_loop_start_unclampled = pad_width - filter_x; - out_x_loop_end_unclampled = pad_width + input_width - filter_x; - } - // The kernel will have to iterate on the segment of the - // output row that starts at out_x_loop_start and out_x_loop_end. - const int out_x_loop_start = - std::max(out_x_buffer_start, out_x_loop_start_unclampled); - const int out_x_loop_end = - std::min(out_x_buffer_end, out_x_loop_end_unclampled); - - float* acc_buffer_ptr = - acc_buffer + (out_x_loop_start - out_x_buffer_start) * output_depth; - const int in_x_origin = (out_x_loop_start * stride) - pad_width + filter_x; - const float* input_ptr = input_data + in_x_origin * input_depth; - const int num_output_pixels = out_x_loop_end - out_x_loop_start; - FloatDepthwiseConvKernel::Run(num_output_pixels, - input_depth, - depth_multiplier, - input_ptr, - input_ptr_increment, - filter_base_ptr, - acc_buffer_ptr); - filter_base_ptr += output_depth; - } -} - -// generic fallback of FloatDepthwiseConvAccumRow, portable, non-templatized. -inline void FloatDepthwiseConvAccumRowGeneric( - int stride, int input_depth, int input_width, const float* input_data, - int pad_width, int depth_multiplier, int filter_width, - const float* filter_data, int out_x_buffer_start, int out_x_buffer_end, - int output_depth, float* acc_buffer) { - gemmlowp::ScopedProfilingLabel label("DepthwiseConvAccumRowGeneric (slow)"); - - VLOG(1) << "DepthwiseConv2d using slow path with " - << "stride = " << stride << ", " - << "input_depth = " << input_depth << ", " - << "depth_multiplier = " << depth_multiplier << "."; - - const float* filter_base_ptr = filter_data; - for (int filter_x = 0; filter_x < filter_width; ++filter_x) { - const int out_x_loop_start = std::max( - out_x_buffer_start, (pad_width - filter_x + stride - 1) / stride); - const int out_x_loop_end = - std::min(out_x_buffer_end, - (pad_width + input_width - filter_x + stride - 1) / stride); - - float* acc_buffer_ptr = - acc_buffer + (out_x_loop_start - out_x_buffer_start) * output_depth; - const int in_x_origin = (out_x_loop_start * stride) - pad_width + filter_x; - const float* input_ptr = input_data + in_x_origin * input_depth; - const int input_ptr_increment = (stride - 1) * input_depth; - for (int out_x = out_x_loop_start; out_x < out_x_loop_end; out_x++) { - const float* filter_ptr = filter_base_ptr; - for (int ic = 0; ic < input_depth; ++ic) { - const float input_val = *input_ptr++; - for (int m = 0; m < depth_multiplier; m++) { - const float filter_val = *filter_ptr++; - *acc_buffer_ptr++ += filter_val * input_val; - } - } - input_ptr += input_ptr_increment; - } - filter_base_ptr += output_depth; - } -} - -// Initializes the accumulator buffer with bias values. -inline void DepthwiseConvInitAccBuffer(int num_output_pixels, int output_depth, - const float* bias_data, - float* acc_buffer) { - // TODO(benoitjacob): This might need optimized specializations - // for small output_depth values, if that ever becomes an important - // case (like it was for some quantized DepthwiseConv cases). - for (int i = 0; i < num_output_pixels; i++) { - memcpy(acc_buffer + i * output_depth, bias_data, - sizeof(acc_buffer[0]) * output_depth); - } -} - -template -void DepthwiseConv(const float* input_data, const Dims<4>& input_dims, - const float* filter_data, const Dims<4>& filter_dims, - const float* bias_data, const Dims<4>& bias_dims, int stride, - int pad_width, int pad_height, int depth_multiplier, - float* output_data, const Dims<4>& output_dims) { - gemmlowp::ScopedProfilingLabel label("DepthwiseConv"); - static_assert(Ac == FusedActivationFunctionType::kNone || - Ac == FusedActivationFunctionType::kRelu || - Ac == FusedActivationFunctionType::kRelu6 || - Ac == FusedActivationFunctionType::kRelu1, - ""); - const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); - const int output_depth = MatchingArraySize(filter_dims, 0, output_dims, 0); - const int input_height = ArraySize(input_dims, 2); - const int input_width = ArraySize(input_dims, 1); - const int input_depth = ArraySize(input_dims, 0); - const int filter_height = ArraySize(filter_dims, 2); - const int filter_width = ArraySize(filter_dims, 1); - const int output_height = ArraySize(output_dims, 2); - const int output_width = ArraySize(output_dims, 1); - DCHECK(output_depth == input_depth * depth_multiplier); - - static const int kAccBufferMaxSize = 1024; - float acc_buffer[kAccBufferMaxSize]; - DCHECK_GE(kAccBufferMaxSize, output_depth) - << "Too small kAccBufferMaxSize for this model!"; - const int kOutputPixelsInAccBuffer = kAccBufferMaxSize / output_depth; - const int kAccBufferActualSize = kOutputPixelsInAccBuffer * output_depth; - DCHECK_LE(kOutputPixelsInAccBuffer * output_depth, kAccBufferActualSize); - DCHECK_LE(kAccBufferActualSize, kAccBufferMaxSize); - DCHECK_GE(kOutputPixelsInAccBuffer, 1); - - // row_accum_func will point to the core accumulation function to be used - // for this DepthwiseConv op. - auto* row_accum_func = FloatDepthwiseConvAccumRowGeneric; - - const int kMaxFixedDepthMultiplier = 8; - int fixed_depth_multiplier = 0; - if (depth_multiplier <= kMaxFixedDepthMultiplier) { - fixed_depth_multiplier = depth_multiplier; - } - // kMaxUnrolling is the max number of output values that we aim to handle - // in one unrolled iteration of the inner loop. For practical performance - // reasons, it is limited by the number of available registers. We could - // fine-tune it depending on the architecture, but that's not worth doing - // since this whole code is not very optimized to begin with. The - // present value reflects what's realistic on ARM 32bit NEON with 16 128-bit - // vector registers. - const int kMaxUnrolling = 8; - int fixed_input_depth = 0; - if (fixed_depth_multiplier && - input_depth * fixed_depth_multiplier <= kMaxUnrolling) { - fixed_input_depth = input_depth; - } -#define TF_NEON_USE_DEPTHWISECONV_KERNEL(ALLOW_STRIDED, FIXED_INPUT_DEPTH, \ - FIXED_DEPTH_MULTIPLIER) \ - if ((stride == 1 || ALLOW_STRIDED) && \ - fixed_input_depth == FIXED_INPUT_DEPTH && \ - fixed_depth_multiplier == FIXED_DEPTH_MULTIPLIER) { \ - row_accum_func = \ - FloatDepthwiseConvAccumRow; \ - } - -#ifdef USE_NEON - TF_NEON_USE_DEPTHWISECONV_KERNEL(true, 0, 1) - TF_NEON_USE_DEPTHWISECONV_KERNEL(true, 0, 8) - TF_NEON_USE_DEPTHWISECONV_KERNEL(true, 0, 2) - TF_NEON_USE_DEPTHWISECONV_KERNEL(false, 8, 1) - TF_NEON_USE_DEPTHWISECONV_KERNEL(false, 2, 1) -#endif // USE_NEON - -#undef TF_NEON_USE_DEPTHWISECONV_KERNEL - - // Now that we have determined row_accum_func, we can start work. - float* output_ptr = output_data; - for (int b = 0; b < batches; ++b) { - for (int out_y = 0; out_y < output_height; ++out_y) { - const int in_y_origin = (out_y * stride) - pad_height; - const int filter_y_start = std::max(0, -in_y_origin); - const int filter_y_end = - std::min(filter_height, input_height - in_y_origin); - for (int out_x_buffer_start = 0; out_x_buffer_start < output_width; - out_x_buffer_start += kOutputPixelsInAccBuffer) { - const int out_x_buffer_end = std::min( - output_width, out_x_buffer_start + kOutputPixelsInAccBuffer); - // We call a 'pixel' a group of activation that share all but the - // 'depth'/'channel' coordinate. num_output_pixels is the number of - // output pixels that we will accumulate in this loop iteration. - const int num_output_pixels = out_x_buffer_end - out_x_buffer_start; - // Initialize our local accumulator with the bias values, so we don't - // have to add them later. - DepthwiseConvInitAccBuffer(num_output_pixels, output_depth, bias_data, - acc_buffer); - // Accumulation loop. Most of the time should be spent in here. - for (int filter_y = filter_y_start; filter_y < filter_y_end; - ++filter_y) { - const int in_y = in_y_origin + filter_y; - row_accum_func(stride, input_depth, input_width, - input_data + in_y * input_dims.strides[2] + - b * input_dims.strides[3], - pad_width, depth_multiplier, filter_width, - filter_data + filter_y * filter_dims.strides[2], - out_x_buffer_start, out_x_buffer_end, output_depth, - acc_buffer); - } - // Finished accumulating. Now store to destination. - const int num_output_values = output_depth * num_output_pixels; - int i = 0; -// TODO(benoitjacob) optimized code goes here -#ifdef USE_NEON - // Handle 16 values at a time - for (; i <= num_output_values - 16; i += 16) { - float32x4_t acc[4]; - for (int k = 0; k < 4; k++) { - acc[k] = vld1q_f32(acc_buffer + i + 4 * k); - } - if (Ac == FusedActivationFunctionType::kRelu) { - for (int k = 0; k < 4; k++) { - acc[k] = vmaxq_f32(vdupq_n_f32(0.f), acc[k]); - } - } else if (Ac == FusedActivationFunctionType::kRelu6) { - for (int k = 0; k < 4; k++) { - acc[k] = vmaxq_f32(vdupq_n_f32(0.f), - vminq_f32(vdupq_n_f32(6.f), acc[k])); - } - } else if (Ac == FusedActivationFunctionType::kRelu1) { - for (int k = 0; k < 4; k++) { - acc[k] = vmaxq_f32(vdupq_n_f32(-1.f), - vminq_f32(vdupq_n_f32(1.f), acc[k])); - } - } - for (int k = 0; k < 4; k++) { - vst1q_f32(output_ptr + 4 * k, acc[k]); - } - output_ptr += 16; - } - // Handle 4 values at a time - for (; i <= num_output_values - 4; i += 4) { - float32x4_t acc = vld1q_f32(acc_buffer + i); - if (Ac == FusedActivationFunctionType::kRelu) { - acc = vmaxq_f32(vdupq_n_f32(0.f), acc); - } else if (Ac == FusedActivationFunctionType::kRelu6) { - acc = vmaxq_f32(vdupq_n_f32(0.f), vminq_f32(vdupq_n_f32(6.f), acc)); - } else if (Ac == FusedActivationFunctionType::kRelu1) { - acc = - vmaxq_f32(vdupq_n_f32(-1.f), vminq_f32(vdupq_n_f32(1.f), acc)); - } - vst1q_f32(output_ptr, acc); - output_ptr += 4; - } -#endif - // Handle leftover values, one by one. This is very slow. - for (; i < num_output_values; i++) { - float acc = acc_buffer[i]; - if (Ac == FusedActivationFunctionType::kRelu) { - acc = std::max(0.f, acc); - } else if (Ac == FusedActivationFunctionType::kRelu6) { - acc = std::max(0.f, std::min(6.f, acc)); - } else if (Ac == FusedActivationFunctionType::kRelu1) { - acc = std::max(-1.f, std::min(1.f, acc)); - } - *output_ptr++ = acc; - } - } - } - } -} - -} // end namespace neon -} // end namespace tensorflow - -#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_NEON_DEPTHWISECONV_H_ diff --git a/tensorflow/core/kernels/neon/neon_depthwise_conv_op.cc b/tensorflow/core/kernels/neon/neon_depthwise_conv_op.cc deleted file mode 100644 index 54b2a10dd8..0000000000 --- a/tensorflow/core/kernels/neon/neon_depthwise_conv_op.cc +++ /dev/null @@ -1,200 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include -#include - -#include "tensorflow/core/framework/numeric_op.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/register_types.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/framework/tensor_types.h" -#include "tensorflow/core/framework/types.h" -#include "tensorflow/core/kernels/neon/depthwiseconv_float.h" -#include "tensorflow/core/kernels/ops_util.h" -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/mem.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/padding.h" - -namespace tensorflow { - -// A version of tensorflow/core/kernels/depthwise_conv_op.cc that -// uses the neon intrinsics. -class NeonDepthwiseConv2dNativeOp : public BinaryOp { - public: - explicit NeonDepthwiseConv2dNativeOp(OpKernelConstruction* context) - : BinaryOp(context) { - OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_)); - OP_REQUIRES(context, strides_.size() == 4, - errors::InvalidArgument("Sliding window strides field must " - "specify 4 dimensions")); - OP_REQUIRES(context, strides_[1] == strides_[2], - errors::InvalidArgument( - "Current implementation only supports equal length " - "strides in the row and column dimensions.")); - OP_REQUIRES( - context, (strides_[0] == 1 && strides_[3] == 1), - errors::InvalidArgument("Current implementation does not yet support " - "strides in the batch and depth dimensions.")); - OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); - } - - void Compute(OpKernelContext* context) override { - const Tensor& input = context->input(0); - const Tensor& filter = context->input(1); - - // For 2D convolution, there should be 4 dimensions. - OP_REQUIRES(context, input.dims() == 4, - errors::InvalidArgument("input must be 4-dimensional", - input.shape().DebugString())); - OP_REQUIRES(context, filter.dims() == 4, - errors::InvalidArgument("filter must be 4-dimensional: ", - filter.shape().DebugString())); - - const int32 in_depth = input.dim_size(3); - OP_REQUIRES( - context, in_depth == filter.dim_size(2), - errors::InvalidArgument("input and filter must have the same depth: ", - in_depth, " vs ", filter.dim_size(2))); - const int32 batch = input.dim_size(0); - const int32 input_rows = input.dim_size(1); - const int32 input_cols = input.dim_size(2); - - const int32 filter_rows = filter.dim_size(0); - const int32 filter_cols = filter.dim_size(1); - const int32 depth_multiplier = filter.dim_size(3); - - const int32 out_depth = in_depth * depth_multiplier; - - const int32 stride = strides_[1]; - - int64 out_rows = 0, out_cols = 0, pad_rows = 0, pad_cols = 0; - OP_REQUIRES_OK(context, - GetWindowedOutputSize(input_rows, filter_rows, stride, - padding_, &out_rows, &pad_rows)); - OP_REQUIRES_OK(context, - GetWindowedOutputSize(input_cols, filter_cols, stride, - padding_, &out_cols, &pad_cols)); - TensorShape out_shape({batch, out_rows, out_cols, out_depth}); - OP_REQUIRES( - context, out_shape.num_elements() <= 2147483647, - errors::InvalidArgument("total number of outputs should be within the " - "range of int which is used in the GPU kernel", - in_depth, " vs ", filter.dim_size(2))); - - // Output tensor is of the following dimensions: - // [ in_batch, out_rows, out_cols, out_depth ] - Tensor* output = nullptr; - OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output)); - - VLOG(2) << "NeonDepthwiseConv2dNative: " - << " Input: [" << batch << ", " << input_rows << ", " << input_cols - << ", " << in_depth << "]; Filter: [" << filter_rows << ", " - << filter_cols << ", " << in_depth << ", " << depth_multiplier - << "]; stride = " << stride << ", pad_rows = " << pad_rows - << ", pad_cols = " << pad_cols << ", output: [" << batch << ", " - << out_rows << ", " << out_cols << ", " << out_depth << "]"; - - // If there is nothing to compute, return. - if (out_shape.num_elements() == 0) { - return; - } - - const float* input_ptr = input.template flat().data(); - const float* filter_ptr = filter.template flat().data(); - float* output_ptr = output->template flat().data(); - - auto input_neon_dims = ToNeonDims(input.shape()); - auto filter_neon_dims = FilterToNeonDims(filter.shape()); - auto bias_neon_dims = BiasNeonDims(filter.shape()); - - int64 bias_size = bias_neon_dims.sizes[3] * bias_neon_dims.strides[3]; - float* bias_ptr = static_cast(port::AlignedMalloc( - bias_size * sizeof(float), Allocator::kAllocatorAlignment)); - memset(bias_ptr, 0, bias_size * sizeof(float)); - - neon::DepthwiseConv( - input_ptr, input_neon_dims, filter_ptr, filter_neon_dims, bias_ptr, - bias_neon_dims, stride, pad_cols, pad_rows, depth_multiplier, - output_ptr, ToNeonDims(out_shape)); - - port::AlignedFree(bias_ptr); - } - - private: - void SetNeonDimStrides(neon::Dims<4>* d) { - int64 stride = 1; - for (int i = 0; i < 4; ++i) { - d->strides[i] = stride; - stride *= d->sizes[i]; - } - } - - neon::Dims<4> ToNeonDims(const TensorShape& input) { - // Dims in the neon kernels are channel, x, y, batch order. - neon::Dims<4> result; - result.sizes[0] = input.dim_size(3); - result.sizes[1] = input.dim_size(2); - result.sizes[2] = input.dim_size(1); - result.sizes[3] = input.dim_size(0); - SetNeonDimStrides(&result); - return result; - } - - neon::Dims<4> FilterToNeonDims(const TensorShape& filter) { - // Dims in the neon kernels are channel, x, y, batch order. - neon::Dims<4> result; - result.sizes[0] = filter.dim_size(2) * filter.dim_size(3); - result.sizes[1] = filter.dim_size(1); - result.sizes[2] = filter.dim_size(0); - result.sizes[3] = 1; - SetNeonDimStrides(&result); - - return result; - } - - neon::Dims<4> BiasNeonDims(const TensorShape& filter) { - // Dims in the neon kernels are channel, x, y, batch order. - // Bias has only output channel set. - neon::Dims<4> result; - result.sizes[0] = filter.dim_size(3); // output channels - result.sizes[1] = 1; - result.sizes[2] = 1; - result.sizes[3] = 1; - SetNeonDimStrides(&result); - - return result; - } - - std::vector strides_; - Padding padding_; - - TF_DISALLOW_COPY_AND_ASSIGN(NeonDepthwiseConv2dNativeOp); -}; - -#define REGISTER_CPU_KERNEL(T) \ - REGISTER_KERNEL_BUILDER(Name("DepthwiseConv2dNative") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("T") \ - .Label("neon"), \ - NeonDepthwiseConv2dNativeOp); - -TF_CALL_float(REGISTER_CPU_KERNEL); - -} // namespace tensorflow diff --git a/tensorflow/core/kernels/neon/types.h b/tensorflow/core/kernels/neon/types.h deleted file mode 100644 index e258ee0dfb..0000000000 --- a/tensorflow/core/kernels/neon/types.h +++ /dev/null @@ -1,71 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_NEON_TYPES_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_NEON_TYPES_H_ - -namespace tensorflow { -namespace neon { - -enum class FusedActivationFunctionType { kNone, kRelu6, kRelu1, kRelu }; - -template -struct Dims { - int sizes[N]; - int strides[N]; -}; - -inline int Offset(const Dims<4>& dims, int i0, int i1, int i2, int i3) { - DCHECK(i0 >= 0 && i0 < dims.sizes[0]); - DCHECK(i1 >= 0 && i1 < dims.sizes[1]); - DCHECK(i2 >= 0 && i2 < dims.sizes[2]); - DCHECK(i3 >= 0 && i3 < dims.sizes[3]); - return i0 * dims.strides[0] + i1 * dims.strides[1] + i2 * dims.strides[2] + - i3 * dims.strides[3]; -} - -// Get array size, DCHECKing that the dim index is in range. -template -int ArraySize(const Dims& array, int index) { - DCHECK(index >= 0 && index < N); - return array.sizes[index]; -} - -// Get common array size, DCHECKing that they all agree. -template -int MatchingArraySize(const ArrayType1& array1, int index1, - const ArrayType2& array2, int index2) { - DCHECK_EQ(ArraySize(array1, index1), ArraySize(array2, index2)); - return ArraySize(array1, index1); -} - -template -int MatchingArraySize(const ArrayType1& array1, int index1, - const ArrayType2& array2, int index2, Args... args) { - DCHECK_EQ(ArraySize(array1, index1), ArraySize(array2, index2)); - return MatchingArraySize(array1, index1, args...); -} - -inline int RequiredBufferSizeForDims(const Dims<4>& dims) { - int max_offset = 0; - for (int i = 0; i < 4; i++) { - max_offset += (dims.sizes[i] - 1) * dims.strides[i]; - } - return max_offset + 1; -} - -} // end namespace neon -} // end namespace tensorflow - -#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_NEON_TYPES_H_ diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD index fb886aaf7d..6689d6a6b4 100644 --- a/tensorflow/python/kernel_tests/BUILD +++ b/tensorflow/python/kernel_tests/BUILD @@ -2087,21 +2087,6 @@ cuda_py_test( ], ) -tf_py_test( - name = "neon_depthwise_conv_op_test", - size = "medium", - srcs = ["neon_depthwise_conv_op_test.py"], - additional_deps = [ - "//third_party/py/numpy", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:array_ops", - "//tensorflow/python:nn", - "//tensorflow/python:nn_grad", - "//tensorflow/python:nn_ops", - ], -) - cuda_py_test( name = "division_future_test", size = "medium", diff --git a/tensorflow/python/kernel_tests/neon_depthwise_conv_op_test.py b/tensorflow/python/kernel_tests/neon_depthwise_conv_op_test.py deleted file mode 100644 index 30795eed8a..0000000000 --- a/tensorflow/python/kernel_tests/neon_depthwise_conv_op_test.py +++ /dev/null @@ -1,287 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Functional tests for neon kernel for depthwise convolutional operations.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.python.framework import constant_op -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import nn_impl -from tensorflow.python.ops import nn_ops -import tensorflow.python.ops.nn_grad # pylint: disable=unused-import -from tensorflow.python.platform import test - - -def ConfigsToTest(): - """Iterator for different convolution shapes, strides and paddings. - - Yields: - Tuple (input_size, filter_size, out_size, stride, padding), the depthwise - convolution parameters. - """ - input_sizes = [[4, 5, 5, 48], [4, 8, 8, 84], [4, 17, 17, 48], [4, 35, 35, 2], - [4, 147, 147, 2], [3, 299, 299, 3], [5, 183, 183, 1]] - filter_sizes = [[1, 1, 48, 2], [1, 3, 84, 1], [3, 1, 48, 4], [5, 5, 2, 1], - [3, 3, 2, 8], [2, 2, 3, 8], [5, 5, 1, 2]] - out_sizes = [[4, 5, 5, 96], [4, 8, 8, 84], [4, 17, 17, 192], [4, 35, 35, 2], - [4, 49, 49, 16], [3, 150, 150, 24], [5, 92, 92, 2]] - strides = [1, 1, 1, 1, 3, 2, 2] - # pylint: disable=invalid-name - VALID = "VALID" - SAME = "SAME" - # pylint: enable=invalid-name - paddings = [SAME, SAME, SAME, SAME, VALID, SAME, SAME, SAME] - for i, f, o, s, p in zip(input_sizes, filter_sizes, out_sizes, strides, - paddings): - yield i, f, o, s, p - - -def CheckGradConfigsToTest(): - """Iterator for different convolution shapes, strides and paddings. - - compute_gradient_error() is very expensive. So the configs should be - relatively small. - - Yields: - Tuple (input_size, filter_size, out_size, stride, padding), the depthwise - convolution parameters. - """ - input_sizes = [[2, 5, 8, 1], [4, 5, 5, 1], [2, 4, 4, 2], [1, 15, 15, 2], - [2, 15, 16, 1]] - filter_sizes = [[4, 4, 1, 2], [2, 2, 1, 2], [3, 1, 2, 2], [1, 3, 2, 1], - [3, 3, 1, 2]] - out_sizes = [[2, 5, 8, 2], [4, 2, 2, 2], [2, 4, 4, 4], [1, 15, 15, 2], - [2, 5, 5, 2]] - strides = [1, 2, 1, 1, 3] - # pylint: disable=invalid-name - VALID = "VALID" - SAME = "SAME" - # pylint: enable=invalid-name - paddings = [SAME, VALID, SAME, SAME, VALID] - for i, f, o, s, p in zip(input_sizes, filter_sizes, out_sizes, strides, - paddings): - yield i, f, o, s, p - - -class DepthwiseConv2DTest(test.TestCase): - - # This is testing that depthwise_conv2d and depthwise_conv2d_native - # produce the same results. It also tests that NCHW and NWHC - # formats agree, by comparing the depthwise_conv2d_native with - # 'NCHW' format (with transposition) matches the 'NHWC' format using - # the higher level interface. - def _VerifyValues(self, - tensor_in_sizes, - filter_in_sizes, - stride, - padding, - use_gpu, - data_format="NHWC"): - """Verifies the output values of the convolution function. - - Args: - tensor_in_sizes: Input tensor dimensions in - [batch, input_rows, input_cols, input_depth]. - filter_in_sizes: Filter tensor dimensions in - [filter_rows, filter_cols, input_depth, depth_multiplier]. - stride: Stride. - padding: Padding type. - use_gpu: Whether to use GPU. - data_format: The data_format of the input. "NHWC" or "NCHW". - """ - total_size_1 = 1 - total_size_2 = 1 - for s in tensor_in_sizes: - total_size_1 *= s - for s in filter_in_sizes: - total_size_2 *= s - # Initializes the input and filter tensor with numbers incrementing from 1. - x1 = [f * 1.0 for f in range(1, total_size_1 + 1)] - x2 = [f * 1.0 for f in range(1, total_size_2 + 1)] - with self.test_session(use_gpu=use_gpu) as sess: - with sess.graph._kernel_label_map({"DepthwiseConv2dNative": "neon"}): - t1 = constant_op.constant(x1, shape=tensor_in_sizes) - t1.set_shape(tensor_in_sizes) - t2 = constant_op.constant(x2, shape=filter_in_sizes) - - native_t1 = t1 - strides = [1, stride, stride, 1] - if data_format == "NCHW": - # Transpose from NWHC input to NCHW - # Ex. [4, 5, 5, 48] to [4, 48, 5, 5] - native_t1 = array_ops.transpose(t1, [0, 3, 1, 2]) - strides = [1, 1, stride, stride] - - conv_native = nn_ops.depthwise_conv2d_native( - native_t1, - t2, - strides=strides, - data_format=data_format, - padding=padding) - - if data_format == "NCHW": - # Transpose back from NCHW to NHWC - conv_native = array_ops.transpose(conv_native, [0, 2, 3, 1]) - - conv_interface = nn_impl.depthwise_conv2d( - t1, t2, strides=[1, stride, stride, 1], padding=padding) - - native_result = sess.run(conv_native) - interface_result = sess.run(conv_interface) - - print("depthwise conv_2d: ", tensor_in_sizes, "*", filter_in_sizes, - ", stride:", stride, ", padding: ", padding, ", max diff: ", - np.amax(np.absolute(native_result - interface_result))) - self.assertArrayNear( - np.ravel(native_result), np.ravel(interface_result), 1e-5) - self.assertShapeEqual(native_result, conv_native) - self.assertShapeEqual(native_result, conv_interface) - - def testDepthwiseConv2D(self): - for index, (input_size, filter_size, _, stride, - padding) in enumerate(ConfigsToTest()): - print("Processing ", index, "th config.") - if index == 2: - self._VerifyValues( - input_size, filter_size, stride, padding, use_gpu=True) - self._VerifyValues( - input_size, filter_size, stride, padding, use_gpu=False) - - def testDepthwiseConv2DFormat(self): - if not test.is_gpu_available(): - return - - for index, (input_size, filter_size, _, stride, - padding) in enumerate(ConfigsToTest()): - print("Processing ", index, "th config.") - self._VerifyValues( - input_size, - filter_size, - stride, - padding, - use_gpu=True, - data_format="NCHW") - -# This is testing against hand calculated results. - - def _VerifyHandValues(self, tensor_in_sizes, filter_in_sizes, stride, padding, - expected, use_gpu): - """Verifies the output values of the depthwise convolution function. - - Args: - tensor_in_sizes: Input tensor dimensions in - [batch, input_rows, input_cols, input_depth]. - filter_in_sizes: Filter tensor dimensions in - [filter_rows, filter_cols, input_depth, depth_multiplier]. - stride: Stride. - padding: Padding type. - expected: An array containing the expected operation outputs. - use_gpu: Whether to use GPU. - """ - total_size_1 = 1 - total_size_2 = 1 - for s in tensor_in_sizes: - total_size_1 *= s - for s in filter_in_sizes: - total_size_2 *= s - # Initializes the input tensor with array containing incrementing - # numbers from 1. - x1 = [f * 1.0 for f in range(1, total_size_1 + 1)] - x2 = [f * 1.0 for f in range(1, total_size_2 + 1)] - with self.test_session(use_gpu=use_gpu) as sess: - with sess.graph._kernel_label_map({"DepthwiseConv2dNative": "neon"}): - t1 = constant_op.constant(x1, shape=tensor_in_sizes) - t1.set_shape(tensor_in_sizes) - t2 = constant_op.constant(x2, shape=filter_in_sizes) - conv = nn_ops.depthwise_conv2d_native( - t1, t2, strides=[1, stride, stride, 1], padding=padding) - value = sess.run(conv) - print("value = ", value) - self.assertArrayNear(expected, np.ravel(value), 1e-5) - self.assertShapeEqual(value, conv) - - def testConv2D2x2Filter(self): - # The inputs look like this (it's a 3 x 2 matrix, each of depth 2): - # - # [ (1.0, 2.0), (3.0, 4.0), ( 5.0, 6.0) ] - # [ (7.0, 8.0), (9.0, 10.0), (11.0, 12.0) ] - # We can view this as two inputs - # - # input depth 0: - # - # [ 1.0, 3.0, 5.0 ] - # [ 7.0, 9.0, 11.0 ] - # - # input depth 1: - # - # [ 2.0, 4.0, 6.0 ] - # [ 8.0, 10.0, 12.0 ] - # - # The filter looks like this (it has two 2 x 2 patches, each generating 2 - # depths): - # - # filter #0: - # - # [ (1.0, 3.0), ( 5.0, 7.0)] - # [ (9.0, 11.0), (13.0, 15.0)] - # - # filter #1: - # - # [ ( 2.0, 4.0), ( 6.0, 8.0)] - # [ (10.0, 12.0), (14.0, 16.0)] - # - # So the outputs are: - # - # (position 0, 0: in_depth 0, output_depth 0 -- using filter #0) - # 1.0 * 1.0 + 7.0 * 9.0 + 3.0 * 5.0 + 9.0 * 13.0 = 196 - # (position 0, 0: in_depth 0, output_depth 1 -- using filter #1) - # 1.0 * 2.0 + 7.0 * 10.0 + 3.0 * 6.0 + 9.0 * 14.0 = 216 - # (position 0, 0: in_depth 1, output_depth 2 -- using filter #0) - # 2.0 * 3.0 + 8.0 * 11.0 + 4.0 * 7.0 + 10.0 * 15.0 = 272 - # (position 0, 0: in_depth 1, output_depth 3 -- using filter #1) - # 2.0 * 4.0 + 8.0 * 12.0 + 4.0 * 8.0 + 10.0 * 16.0 = 296 - # - # (position 1, 0: in_depth 0, output_depth 0 -- using filter #0) - # 3.0 * 1.0 + 9.0 * 9.0 + 5.0 * 5.0 + 11.0 * 13.0 = 252 - # (position 1, 0: in_depth 0, output_depth 1 -- using filter #1) - # 3.0 * 2.0 + 9.0 * 10.0 + 5.0 * 6.0 + 11.0 * 14.0 = 280 - # (position 1, 0: in_depth 1, output_depth 2 -- using filter #0) - # 4.0 * 3.0 + 10.0 * 11.0 + 6.0 * 7.0 + 12.0 * 15.0 = 344 - # (position 1, 0: in_depth 1, output_depth 3 -- using filter #1) - # 4.0 * 4.0 + 10.0 * 12.0 + 6.0 * 8.0 + 12.0 * 16.0 = 376 - expected_output = [196, 216, 272, 296, 252, 280, 344, 376] - self._VerifyHandValues( - tensor_in_sizes=[1, 2, 3, 2], - filter_in_sizes=[2, 2, 2, 2], - stride=1, - padding="VALID", - expected=expected_output, - use_gpu=False) - - self._VerifyHandValues( - tensor_in_sizes=[1, 2, 3, 2], - filter_in_sizes=[2, 2, 2, 2], - stride=1, - padding="VALID", - expected=expected_output, - use_gpu=True) - - -if __name__ == "__main__": - test.main() -- GitLab From c03d5cc66421b97770ee2b63595c806a90173ef4 Mon Sep 17 00:00:00 2001 From: Ali Siddiqui Date: Tue, 16 May 2017 03:01:07 +0300 Subject: [PATCH 655/697] NADAM Optimizer (#9889) * Initial commit for NADAM * Add GPU and sparse implementations and add missing arguments * Add tester * Revert changes to files made for testing * Add nadam optimizer in a class of its own * Reverse changes to adam_test.py * Reverse changes to adam.py * Actually reverse adam_test.py * Actually reverse adam.py * Delete nadam_optimizer_test.py * Create nadam_optimizer_test.py * Fix BUILD * Run buildifier on BUILD --- tensorflow/contrib/opt/BUILD | 18 ++ tensorflow/contrib/opt/__init__.py | 2 + .../opt/python/training/nadam_optimizer.py | 90 ++++++++++ .../python/training/nadam_optimizer_test.py | 158 ++++++++++++++++++ tensorflow/core/kernels/training_ops.cc | 22 ++- tensorflow/core/kernels/training_ops.h | 3 +- .../core/kernels/training_ops_gpu.cu.cc | 28 +++- tensorflow/core/ops/training_ops.cc | 12 +- 8 files changed, 318 insertions(+), 15 deletions(-) create mode 100644 tensorflow/contrib/opt/python/training/nadam_optimizer.py create mode 100644 tensorflow/contrib/opt/python/training/nadam_optimizer_test.py diff --git a/tensorflow/contrib/opt/BUILD b/tensorflow/contrib/opt/BUILD index 2a8714644c..a7e910975f 100644 --- a/tensorflow/contrib/opt/BUILD +++ b/tensorflow/contrib/opt/BUILD @@ -18,6 +18,7 @@ py_library( "python/training/external_optimizer.py", "python/training/lazy_adam_optimizer.py", "python/training/moving_average_optimizer.py", + "python/training/nadam_optimizer.py", "python/training/variable_clipping_optimizer.py", ], srcs_version = "PY2AND3", @@ -106,6 +107,23 @@ py_test( ], ) +py_test( + name = "nadam_optimizer_test", + srcs = ["python/training/nadam_optimizer_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":opt_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:variables", + "//third_party/py/numpy", + ], +) + tf_py_test( name = "drop_stale_gradient_optimizer_test", srcs = ["python/training/drop_stale_gradient_optimizer_test.py"], diff --git a/tensorflow/contrib/opt/__init__.py b/tensorflow/contrib/opt/__init__.py index 6cd68f29a7..be12f934a4 100644 --- a/tensorflow/contrib/opt/__init__.py +++ b/tensorflow/contrib/opt/__init__.py @@ -22,6 +22,7 @@ from __future__ import print_function from tensorflow.contrib.opt.python.training.drop_stale_gradient_optimizer import * from tensorflow.contrib.opt.python.training.external_optimizer import * from tensorflow.contrib.opt.python.training.lazy_adam_optimizer import * +from tensorflow.contrib.opt.python.training.nadam_optimizer import * from tensorflow.contrib.opt.python.training.moving_average_optimizer import * from tensorflow.contrib.opt.python.training.variable_clipping_optimizer import * # pylint: enable=wildcard-import @@ -31,6 +32,7 @@ from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = ['DropStaleGradientOptimizer', 'ExternalOptimizerInterface', 'LazyAdamOptimizer', + 'NadamOptimizer', 'MovingAverageOptimizer', 'ScipyOptimizerInterface', 'VariableClippingOptimizer'] diff --git a/tensorflow/contrib/opt/python/training/nadam_optimizer.py b/tensorflow/contrib/opt/python/training/nadam_optimizer.py new file mode 100644 index 0000000000..07521bd4ce --- /dev/null +++ b/tensorflow/contrib/opt/python/training/nadam_optimizer.py @@ -0,0 +1,90 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Nadam for TensorFlow.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.training import training_ops +from tensorflow.python.training import adam + + +class NadamOptimizer(adam.AdamOptimizer): + """Optimizer that implements the Nadam algorithm. + + See [Dozat, T., 2015](http://cs229.stanford.edu/proj2015/054_report.pdf). + """ + + def _apply_dense(self, grad, var): + m = self.get_slot(var, "m") + v = self.get_slot(var, "v") + return training_ops.apply_adam( + var, m, v, + math_ops.cast(self._beta1_power, var.dtype.base_dtype), + math_ops.cast(self._beta2_power, var.dtype.base_dtype), + math_ops.cast(self._lr_t, var.dtype.base_dtype), + math_ops.cast(self._beta1_t, var.dtype.base_dtype), + math_ops.cast(self._beta2_t, var.dtype.base_dtype), + math_ops.cast(self._epsilon_t, var.dtype.base_dtype), + grad, use_locking=self._use_locking, + use_nesterov=True).op + + def _resource_apply_dense(self, grad, var): + m = self.get_slot(var, "m") + v = self.get_slot(var, "v") + return training_ops.resource_apply_adam( + var.handle, m.handle, v.handle, + math_ops.cast(self._beta1_power, grad.dtype.base_dtype), + math_ops.cast(self._beta2_power, grad.dtype.base_dtype), + math_ops.cast(self._lr_t, grad.dtype.base_dtype), + math_ops.cast(self._beta1_t, grad.dtype.base_dtype), + math_ops.cast(self._beta2_t, grad.dtype.base_dtype), + math_ops.cast(self._epsilon_t, grad.dtype.base_dtype), + grad, use_locking=self._use_locking, + use_nesterov=True) + + def _apply_sparse_shared(self, grad, var, indices, scatter_add): + beta1_power = math_ops.cast(self._beta1_power, var.dtype.base_dtype) + beta2_power = math_ops.cast(self._beta2_power, var.dtype.base_dtype) + lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype) + beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype) + beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype) + epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype) + lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power)) + # m_t = beta1 * m + (1 - beta1) * g_t + m = self.get_slot(var, "m") + m_scaled_g_values = grad * (1 - beta1_t) + m_t = state_ops.assign(m, m * beta1_t, + use_locking=self._use_locking) + with ops.control_dependencies([m_t]): + m_t = scatter_add(m, indices, m_scaled_g_values) + # m_bar = (1 - beta1) * g_t + beta1 * m_t + m_bar = m_scaled_g_values + beta1_t * m_t + # v_t = beta2 * v + (1 - beta2) * (g_t * g_t) + v = self.get_slot(var, "v") + v_scaled_g_values = (grad * grad) * (1 - beta2_t) + v_t = state_ops.assign(v, v * beta2_t, use_locking=self._use_locking) + with ops.control_dependencies([v_t]): + v_t = scatter_add(v, indices, v_scaled_g_values) + v_sqrt = math_ops.sqrt(v_t) + var_update = state_ops.assign_sub(var, + lr * m_bar / (v_sqrt + epsilon_t), + use_locking=self._use_locking) + return control_flow_ops.group(*[var_update, m_bar, v_t]) diff --git a/tensorflow/contrib/opt/python/training/nadam_optimizer_test.py b/tensorflow/contrib/opt/python/training/nadam_optimizer_test.py new file mode 100644 index 0000000000..3d48684f53 --- /dev/null +++ b/tensorflow/contrib/opt/python/training/nadam_optimizer_test.py @@ -0,0 +1,158 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for Nadam.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.client import session +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test +from tensorflow.contrib.opt.python.training import nadam_optimizer + + +def nadam_update_numpy(param, + g_t, + t, + m, + v, + alpha=0.001, + beta1=0.9, + beta2=0.999, + epsilon=1e-8): + alpha_t = alpha * np.sqrt(1 - beta2**t) / (1 - beta1**t) + + m_t = beta1 * m + (1 - beta1) * g_t + v_t = beta2 * v + (1 - beta2) * g_t * g_t + + m_bar = (1 - beta1) * g_t + beta1 * m_t + + param_t = param - alpha_t * m_bar / (np.sqrt(v_t) + epsilon) + return param_t, m_t, v_t + + +class NadamOptimizerTest(test.TestCase): + + def doTestSparse(self, use_resource=False): + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + with self.test_session(): + # Initialize variables for numpy implementation. + m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) + + if use_resource: + var0 = resource_variable_ops.ResourceVariable(var0_np) + var1 = resource_variable_ops.ResourceVariable(var1_np) + else: + var0 = variables.Variable(var0_np) + var1 = variables.Variable(var1_np) + grads0_np_indices = np.array([0, 1], dtype=np.int32) + grads0 = ops.IndexedSlices( + constant_op.constant(grads0_np), + constant_op.constant(grads0_np_indices), constant_op.constant([2])) + grads1_np_indices = np.array([0, 1], dtype=np.int32) + grads1 = ops.IndexedSlices( + constant_op.constant(grads1_np), + constant_op.constant(grads1_np_indices), constant_op.constant([2])) + opt = nadam_optimizer.NadamOptimizer() + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([3.0, 4.0], var1.eval()) + + beta1_power, beta2_power = opt._get_beta_accumulators() + + # Run 3 steps of Nadam + for t in range(1, 4): + self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval()) + self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval()) + update.run() + + var0_np, m0, v0 = nadam_update_numpy(var0_np, grads0_np, t, m0, v0) + var1_np, m1, v1 = nadam_update_numpy(var1_np, grads1_np, t, m1, v1) + + # Validate updated params + self.assertAllCloseAccordingToType(var0_np, var0.eval()) + self.assertAllCloseAccordingToType(var1_np, var1.eval()) + + def testSparse(self): + self.doTestSparse(use_resource=False) + + def testResourceSparse(self): + self.doTestSparse(use_resource=True) + + def doTestBasic(self, use_resource=False): + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + with self.test_session(): + # Initialize variables for numpy implementation. + m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) + + if use_resource: + var0 = resource_variable_ops.ResourceVariable(var0_np) + var1 = resource_variable_ops.ResourceVariable(var1_np) + else: + var0 = variables.Variable(var0_np) + var1 = variables.Variable(var1_np) + grads0 = constant_op.constant(grads0_np) + grads1 = constant_op.constant(grads1_np) + opt = nadam_optimizer.NadamOptimizer() + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([3.0, 4.0], var1.eval()) + + beta1_power, beta2_power = opt._get_beta_accumulators() + + # Run 3 steps of Nadam + for t in range(1, 4): + self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval()) + self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval()) + update.run() + + var0_np, m0, v0 = nadam_update_numpy(var0_np, grads0_np, t, m0, v0) + var1_np, m1, v1 = nadam_update_numpy(var1_np, grads1_np, t, m1, v1) + + # Validate updated params + self.assertAllCloseAccordingToType(var0_np, var0.eval()) + self.assertAllCloseAccordingToType(var1_np, var1.eval()) + + def testBasic(self): + self.doTestBasic(use_resource=False) + + def testResourceBasic(self): + self.doTestBasic(use_resource=True) + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/core/kernels/training_ops.cc b/tensorflow/core/kernels/training_ops.cc index 3b2fa29693..d05cb1ecb4 100644 --- a/tensorflow/core/kernels/training_ops.cc +++ b/tensorflow/core/kernels/training_ops.cc @@ -245,12 +245,22 @@ struct ApplyAdamNonCuda { typename TTypes::ConstScalar beta1, typename TTypes::ConstScalar beta2, typename TTypes::ConstScalar epsilon, - typename TTypes::ConstFlat grad) { + typename TTypes::ConstFlat grad, + bool use_nesterov) { const T alpha = lr() * Eigen::numext::sqrt(T(1) - beta2_power()) / (T(1) - beta1_power()); + // beta1 == μ + // beta2 == ν + // v == n + // var == θ + m.device(d) += (grad - m) * (T(1) - beta1()); v.device(d) += (grad.square() - v) * (T(1) - beta2()); - var.device(d) -= (m * alpha) / (v.sqrt() + epsilon()); + if (use_nesterov) { + var.device(d) -= ((grad * (T(1) - beta1()) + beta1() * m) * alpha) / (v.sqrt() + epsilon()); + } else { + var.device(d) -= (m * alpha) / (v.sqrt() + epsilon()); + } } }; @@ -2248,6 +2258,7 @@ class ApplyAdamOp : public OpKernel { public: explicit ApplyAdamOp(OpKernelConstruction* ctx) : OpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("use_nesterov", &use_nesterov_)); } void Compute(OpKernelContext* ctx) override { @@ -2322,13 +2333,15 @@ class ApplyAdamOp : public OpKernel { v.flat(), beta1_power.scalar(), beta2_power.scalar(), lr.scalar(), beta1.scalar(), beta2.scalar(), - epsilon.scalar(), grad.flat()); + epsilon.scalar(), grad.flat(), + use_nesterov_); MaybeForwardRefInputToRefOutput(ctx, 0, 0); } private: bool use_exclusive_lock_; + bool use_nesterov_; }; using CPUDevice = Eigen::ThreadPoolDevice; @@ -2372,7 +2385,8 @@ namespace functor { typename TTypes::ConstScalar beta1, \ typename TTypes::ConstScalar beta2, \ typename TTypes::ConstScalar epsilon, \ - typename TTypes::ConstFlat grad); \ + typename TTypes::ConstFlat grad, \ + bool use_nesterov); \ extern template struct ApplyAdam; DECLARE_GPU_SPEC(Eigen::half); DECLARE_GPU_SPEC(float); diff --git a/tensorflow/core/kernels/training_ops.h b/tensorflow/core/kernels/training_ops.h index c96468b270..9c8a1e8d3d 100644 --- a/tensorflow/core/kernels/training_ops.h +++ b/tensorflow/core/kernels/training_ops.h @@ -123,7 +123,8 @@ struct ApplyAdam { typename TTypes::ConstScalar beta1, typename TTypes::ConstScalar beta2, typename TTypes::ConstScalar epsilon, - typename TTypes::ConstFlat grad); + typename TTypes::ConstFlat grad, + bool use_nesterov); }; template diff --git a/tensorflow/core/kernels/training_ops_gpu.cu.cc b/tensorflow/core/kernels/training_ops_gpu.cu.cc index f6acdf2422..c2563c3a49 100644 --- a/tensorflow/core/kernels/training_ops_gpu.cu.cc +++ b/tensorflow/core/kernels/training_ops_gpu.cu.cc @@ -109,7 +109,8 @@ struct ApplyAdam { typename TTypes::ConstScalar beta1, typename TTypes::ConstScalar beta2, typename TTypes::ConstScalar epsilon, - typename TTypes::ConstFlat grad) { + typename TTypes::ConstFlat grad, + bool use_nesterov) { Eigen::array::Tensor::Index, 1> bcast; bcast[0] = grad.dimension(0); Eigen::Sizes<1> single; @@ -122,11 +123,26 @@ struct ApplyAdam { v + (beta2.constant(one) - beta2).reshape(single).broadcast(bcast) * (grad.square() - v); - var.device(d) -= (lr * (beta2_power.constant(one) - beta2_power).sqrt() / - (beta1_power.constant(one) - beta1_power)) - .reshape(single) - .broadcast(bcast) * - m / (epsilon.reshape(single).broadcast(bcast) + v.sqrt()); + + if (use_nesterov) { + var.device(d) -= (lr * (beta2_power.constant(one) - beta2_power).sqrt() / + (beta1_power.constant(one) - beta1_power)) + .reshape(single) + .broadcast(bcast) * + (m * beta1.reshape(single).broadcast(bcast) + + (beta1.constant(one) - beta1) + .reshape(single) + .broadcast(bcast) * + grad) / (epsilon + .reshape(single) + .broadcast(bcast) + v.sqrt()); + } else { + var.device(d) -= (lr * (beta2_power.constant(one) - beta2_power).sqrt() / + (beta1_power.constant(one) - beta1_power)) + .reshape(single) + .broadcast(bcast) * + m / (epsilon.reshape(single).broadcast(bcast) + v.sqrt()); + } } }; diff --git a/tensorflow/core/ops/training_ops.cc b/tensorflow/core/ops/training_ops.cc index 2027bf4603..6f7a007f2c 100644 --- a/tensorflow/core/ops/training_ops.cc +++ b/tensorflow/core/ops/training_ops.cc @@ -1004,7 +1004,7 @@ out: Same as "var". use_locking: If `True`, updating of the var and accum tensors will be protected by a lock; otherwise the behavior is undefined, but may exhibit less contention. -use_nesterov: If `True`, the tensor passed to compute grad will be +use_nesterov: If `True`, the tensor passed to compute grad will be var - lr * momentum * accum, so in the end, the var you get is actually var - lr * momentum * accum. )doc"); @@ -1043,7 +1043,7 @@ out: Same as "var". use_locking: If `True`, updating of the var and accum tensors will be protected by a lock; otherwise the behavior is undefined, but may exhibit less contention. -use_nesterov: If `True`, the tensor passed to compute grad will be +use_nesterov: If `True`, the tensor passed to compute grad will be var - lr * momentum * accum, so in the end, the var you get is actually var - lr * momentum * accum. )doc"); @@ -1075,7 +1075,7 @@ momentum: Momentum. Must be a scalar. use_locking: If `True`, updating of the var and accum tensors will be protected by a lock; otherwise the behavior is undefined, but may exhibit less contention. -use_nesterov: If `True`, the tensor passed to compute grad will be +use_nesterov: If `True`, the tensor passed to compute grad will be var - lr * momentum * accum, so in the end, the var you get is actually var - lr * momentum * accum. )doc"); @@ -1112,7 +1112,7 @@ momentum: Momentum. Must be a scalar. use_locking: If `True`, updating of the var and accum tensors will be protected by a lock; otherwise the behavior is undefined, but may exhibit less contention. -use_nesterov: If `True`, the tensor passed to compute grad will be +use_nesterov: If `True`, the tensor passed to compute grad will be var - lr * momentum * accum, so in the end, the var you get is actually var - lr * momentum * accum. )doc"); @@ -1150,6 +1150,7 @@ REGISTER_OP("ApplyAdam") .Output("out: Ref(T)") .Attr("T: numbertype") .Attr("use_locking: bool = false") + .Attr("use_nesterov: bool = false") .SetShapeFn([](InferenceContext* c) { return ApplyAdamShapeFn(c, false /* sparse */); }) @@ -1175,6 +1176,7 @@ out: Same as "var". use_locking: If `True`, updating of the var, m, and v tensors will be protected by a lock; otherwise the behavior is undefined, but may exhibit less contention. +use_nesterov: If `True`, uses the nesterov update. )doc"); REGISTER_OP("ResourceApplyAdam") @@ -1190,6 +1192,7 @@ REGISTER_OP("ResourceApplyAdam") .Input("grad: T") .Attr("T: numbertype") .Attr("use_locking: bool = false") + .Attr("use_nesterov: bool = false") .SetShapeFn([](InferenceContext* c) { return ApplyAdamShapeFn(c, false /* sparse */); }) @@ -1214,6 +1217,7 @@ grad: The gradient. use_locking: If `True`, updating of the var, m, and v tensors will be protected by a lock; otherwise the behavior is undefined, but may exhibit less contention. +use_nesterov: If `True`, uses the nesterov update. )doc"); static Status ApplyRMSPropShapeFn(InferenceContext* c, bool sparse) { -- GitLab From 6c6a518a95c04074d7f75fcecb9a03410386aa6b Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 15 May 2017 17:23:24 -0700 Subject: [PATCH 656/697] [XLA] Add a test for a zero-element and scalar reshape operation. PiperOrigin-RevId: 156126496 --- tensorflow/compiler/xla/tests/reshape_test.cc | 34 +++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/tensorflow/compiler/xla/tests/reshape_test.cc b/tensorflow/compiler/xla/tests/reshape_test.cc index 839ae42a19..c5f20b9ca1 100644 --- a/tensorflow/compiler/xla/tests/reshape_test.cc +++ b/tensorflow/compiler/xla/tests/reshape_test.cc @@ -67,6 +67,22 @@ XLA_TEST_F(ReshapeTest, SingleElementArrayToScalar) { ComputeAndCompareR0(&builder, 1.0f, {}, zero_error_spec_); } +XLA_TEST_F(ReshapeTest, ScalarToSingleElementArray) { + ComputationBuilder builder(client_, TestName()); + + std::unique_ptr param0_literal = LiteralUtil::CreateR0(1.0f); + std::unique_ptr param0_data = + client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + + auto a = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param0"); + a = builder.Neg(a); + auto reshape = + builder.Reshape(/*operand=*/a, /*dimensions=*/{}, /*new_sizes=*/{1}); + + ComputeAndCompareR1(&builder, {-1.0f}, {param0_data.get()}, + zero_error_spec_); +} + XLA_TEST_F(ReshapeTest, Trivial0x3) { ComputationBuilder builder(client_, TestName()); auto a = builder.ConstantR2FromArray2D(Array2D(0, 3)); @@ -75,6 +91,24 @@ XLA_TEST_F(ReshapeTest, Trivial0x3) { ComputeAndCompareR1(&builder, {}, {}, zero_error_spec_); } +// TODO(b/29185393): Make this work with the GPU backend. The GPU backend +// does not handle zero-sized shapes correctly. Failed last on 2017-05-15 +// with an incorrect result rank. +XLA_TEST_F(ReshapeTest, DISABLED_ON_GPU(Trivial0x3WithParameter)) { + ComputationBuilder builder(client_, TestName()); + + std::unique_ptr param0_literal = + LiteralUtil::CreateR2FromArray2D(Array2D(0, 3)); + std::unique_ptr param0_data = + client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + + auto a = builder.Parameter(0, ShapeUtil::MakeShape(F32, {0, 3}), "param0"); + auto result = builder.Collapse(/*operand=*/a, /*dimensions=*/{0, 1}); + + ComputeAndCompareR1(&builder, {}, {param0_data.get()}, + zero_error_spec_); +} + XLA_TEST_F(ReshapeTest, Trivial3x0) { ComputationBuilder builder(client_, TestName()); auto a = builder.ConstantR2FromArray2D(Array2D(3, 0)); -- GitLab From c695f22cc5c974a66949b4c7f95d165aa34de6b3 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 15 May 2017 18:08:58 -0700 Subject: [PATCH 657/697] TensorFlow: Clean ups to python_op_gen. PiperOrigin-RevId: 156130839 --- tensorflow/python/framework/python_op_gen.cc | 172 +++++++++++-------- 1 file changed, 100 insertions(+), 72 deletions(-) diff --git a/tensorflow/python/framework/python_op_gen.cc b/tensorflow/python/framework/python_op_gen.cc index 220868168a..a3168a0088 100644 --- a/tensorflow/python/framework/python_op_gen.cc +++ b/tensorflow/python/framework/python_op_gen.cc @@ -21,8 +21,11 @@ limitations under the License. #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/op_def.pb_text.h" #include "tensorflow/core/framework/op_def_util.h" #include "tensorflow/core/framework/op_gen_lib.h" +#include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/framework/tensor.pb_text.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/lib/gtl/map_util.h" @@ -66,7 +69,8 @@ bool IsPythonReserved(const string& s) { "ZeroDivisionError", "__debug__", "__doc__", "__import__", "__name__", "__package__", // Imports and symbols used in the generated code: - "_op_def_lib", "text_format", "op_def_pb2", "op_def_library", "ops"}); + "_text_format", "_op_def_pb2", "_common_shapes", "_op_def_registry", + "_ops", "_op_def_library"}); return kPythonReserved->count(s) > 0; } @@ -175,13 +179,12 @@ string ArgTypeName(const OpDef& op_def, const OpDef::ArgDef& arg, prefix = "A list of"; } } else { - prefix = strings::StrCat( - "A list with the same number of `Tensor` objects as `", - AvoidPythonReserved(*original_arg), "` of"); + prefix = strings::StrCat("A list with the same length as `", + AvoidPythonReserved(*original_arg), "` of"); } if (arg.type() != DT_INVALID) { - return strings::StrCat(prefix, " `Tensor` objects of type ", + return strings::StrCat(prefix, " `Tensor` objects with type ", TypeString(arg.type(), arg.is_ref()), "."); } else { original_arg = gtl::FindOrNull(inferred_attrs, arg.type_attr()); @@ -189,20 +192,22 @@ string ArgTypeName(const OpDef& op_def, const OpDef::ArgDef& arg, strings::StrAppend(&prefix, " mutable"); } if (original_arg == nullptr) { - return strings::StrCat(prefix, " `Tensor` objects of type ", - arg.type_attr(), "."); + return strings::StrCat(prefix, " `Tensor` objects with type `", + arg.type_attr(), "`."); } else if (*original_arg == arg.name()) { const OpDef::AttrDef* attr = FindAttr(arg.type_attr(), op_def); if (attr->has_allowed_values()) { return strings::StrCat(prefix, - " `Tensor` objects of the same type in: ", + " `Tensor` objects with the same type in: ", TypeListString(attr->allowed_values()), "."); } else { - return strings::StrCat(prefix, " `Tensor` objects of the same type."); + return strings::StrCat(prefix, + " `Tensor` objects with the same type."); } } else { - return strings::StrCat(prefix, " `Tensor` objects of the same type as ", - AvoidPythonReserved(*original_arg), "."); + return strings::StrCat(prefix, + " `Tensor` objects with the same type as `", + AvoidPythonReserved(*original_arg), "`."); } } } else if (!arg.type_attr().empty() || !arg.type_list_attr().empty()) { @@ -241,19 +246,19 @@ string ArgTypeName(const OpDef& op_def, const OpDef::ArgDef& arg, } } -static string GetReturns(const OpDef& op_def, - const std::vector& output_type_string) { +string GetReturns(const OpDef& op_def, + const std::vector& output_type_string) { string result; DCHECK_EQ(op_def.output_arg_size(), output_type_string.size()); const int num_outs = op_def.output_arg_size(); - strings::Appendf(&result, "\n Returns:\n"); + strings::StrAppend(&result, "\n Returns:\n"); if (num_outs == 0) { - strings::Appendf(&result, " The created Operation.\n"); + strings::StrAppend(&result, " The created Operation.\n"); } else { if (num_outs == 1) { StringPiece description = op_def.output_arg(0).description(); if (ConsumeEquals(&description)) { // Skip the generated type info. - strings::Appendf(&result, "%s", Indent(4, 4, description).c_str()); + strings::StrAppend(&result, Indent(4, 4, description)); } else { // Special case of one output, don't use the name of the output unless // there is no description. @@ -272,7 +277,7 @@ static string GetReturns(const OpDef& op_def, } else if (!description.empty()) { AppendWithinWidth(&desc, description, kRightMargin - 4 /* indent */); } - strings::Appendf(&result, "%s", Indent(4, 4, desc).c_str()); + strings::StrAppend(&result, Indent(4, 4, desc)); } } else { std::vector out_names(num_outs); @@ -283,8 +288,8 @@ static string GetReturns(const OpDef& op_def, out_names[i] = strings::StrCat("output", i); } } - strings::Appendf(&result, " A tuple of `Tensor` objects (%s).\n\n", - str_util::Join(out_names, ", ").c_str()); + strings::StrAppend(&result, " A tuple of `Tensor` objects (", + str_util::Join(out_names, ", "), ").\n\n"); for (int i = 0; i < num_outs; ++i) { string desc = strings::StrCat(out_names[i], ": "); StringPiece description = op_def.output_arg(i).description(); @@ -307,7 +312,7 @@ static string GetReturns(const OpDef& op_def, strings::StrAppend(&desc, type); } } - strings::Appendf(&result, "%s", Indent(4, 6, desc).c_str()); + strings::StrAppend(&result, Indent(4, 6, desc)); } } } @@ -337,6 +342,10 @@ string ShapeToPython(const TensorShapeProto& shape) { return python; } +string TensorToPython(const TensorProto& proto) { + return ProtoShortDebugString(proto); +} + string AttrListToPython(const AttrValue& value) { string ret; if (value.list().s_size() > 0) { @@ -369,6 +378,16 @@ string AttrListToPython(const AttrValue& value) { if (i > 0) strings::StrAppend(&ret, ", "); strings::StrAppend(&ret, ShapeToPython(value.list().shape(i))); } + } else if (value.list().tensor_size() > 0) { + for (int i = 0; i < value.list().tensor_size(); ++i) { + if (i > 0) strings::StrAppend(&ret, ", "); + strings::StrAppend(&ret, TensorToPython(value.list().tensor(i))); + } + } else if (value.list().func_size() > 0) { + for (int i = 0; i < value.list().func_size(); ++i) { + if (i > 0) strings::StrAppend(&ret, ", "); + strings::StrAppend(&ret, StringToPython(value.list().func(i).name())); + } } return ret; } @@ -386,8 +405,14 @@ string AttrValueToPython(const string& type, const AttrValue& value) { return DataTypeToPython(value.type()); } else if (type == "shape") { return ShapeToPython(value.shape()); - } else { + } else if (type == "tensor") { + return TensorToPython(value.tensor()); + } else if (type == "func") { + return StringToPython(value.func().name()); + } else if (StringPiece(type).starts_with("list(")) { return strings::StrCat("[", AttrListToPython(value), "]"); + } else { + return "?"; } } @@ -417,7 +442,7 @@ string GetPythonOp(const OpDef& op_def, bool is_hidden, const string& op_name) { // defaults. std::vector args_no_default; // The parameters with defaults (these have to be listed after those without). - // No input args are included, just attrs and the graph ("g") parameter. + // No input args are included, just attrs. std::vector args_with_defaults; for (int i = 0; i < op_def.input_arg_size(); ++i) { const auto& arg(op_def.input_arg(i)); @@ -448,8 +473,7 @@ string GetPythonOp(const OpDef& op_def, bool is_hidden, const string& op_name) { // those with defaults go at the end. std::vector attrs; // Get the attrs in the order we want by taking the attrs without defaults - // from the end of args_no_default, and adding args_no_default (before - // "g" gets added to args_no_default, so it only has attrs). + // from the end of args_no_default, and adding args_no_default. attrs.reserve(args_no_default.size() - op_def.input_arg_size() + args_with_defaults.size()); attrs.insert(attrs.end(), args_no_default.begin() + op_def.input_arg_size(), @@ -472,51 +496,51 @@ string GetPythonOp(const OpDef& op_def, bool is_hidden, const string& op_name) { strings::StrAppend(¶meters, param, "=None"); param_names.push_back(param); } - const bool has_args = args_no_default.size() + args_with_defaults.size() > 0; const string lower_op_name = strings::StrCat(is_hidden ? "_" : "", op_name); - // Prepare the list of output names const int num_outs = op_def.output_arg_size(); - std::vector out_names(num_outs); - for (int i = 0; i < num_outs; ++i) { - if (!op_def.output_arg(i).name().empty()) { - out_names[i] = op_def.output_arg(i).name(); - } else { - out_names[i] = strings::StrCat("output", i); - } - } - string out_names_list = - strings::StrCat("[\"", str_util::Join(out_names, "\", \""), "\"]"); - - // Provide the output names as a Python list - string lower_op_name_outputs = - strings::StrCat("_", lower_op_name, "_outputs"); - const string outputs_prefix = strings::StrCat(lower_op_name_outputs, " = "); - strings::Appendf( - &result, "%s\n", - WordWrap(outputs_prefix, out_names_list, kRightMargin).c_str()); - strings::Appendf(&result, "\n\n"); - // Prepare a NamedTuple type to hold the outputs, if there are multiple if (num_outs > 1) { - const string tuple_type_prefix = strings::StrCat( - "_", op_def.name(), "Output = _collections.namedtuple("); + // Prepare the list of output names + std::vector out_names(num_outs); + for (int i = 0; i < num_outs; ++i) { + if (!op_def.output_arg(i).name().empty()) { + out_names[i] = op_def.output_arg(i).name(); + } else { + out_names[i] = strings::StrCat("output", i); + } + } + string out_names_list = + strings::StrCat("[\"", str_util::Join(out_names, "\", \""), "\"]"); + + // Provide the output names as a Python list + string lower_op_name_outputs = + strings::StrCat("_", lower_op_name, "_outputs"); + const string outputs_prefix = strings::StrCat(lower_op_name_outputs, " = "); + strings::StrAppend(&result, "\n", + WordWrap(outputs_prefix, out_names_list, kRightMargin), + "\n"); + + strings::StrAppend(&result, "_", op_def.name(), + "Output = _collections.namedtuple(\n"); + const string tuple_type_prefix = " "; const string tuple_type_suffix = strings::StrCat( "\"", op_def.name(), "\", ", lower_op_name_outputs, ")"); - strings::Appendf( - &result, "%s\n", - WordWrap(tuple_type_prefix, tuple_type_suffix, kRightMargin).c_str()); - strings::Appendf(&result, "\n\n"); + strings::StrAppend( + &result, WordWrap(tuple_type_prefix, tuple_type_suffix, kRightMargin), + "\n\n"); } + strings::StrAppend(&result, "\n"); // Print: def Function(parameters): const string def_prefix = strings::StrCat("def ", lower_op_name, "("); + const bool has_args = args_no_default.size() + args_with_defaults.size() > 0; const string def_suffix = strings::StrCat(parameters, has_args ? ", " : "", "name=None):"); - strings::Appendf(&result, "%s\n", - WordWrap(def_prefix, def_suffix, kRightMargin).c_str()); + strings::StrAppend(&result, WordWrap(def_prefix, def_suffix, kRightMargin), + "\n"); // Format the Op's descriptions so that it can be a Python docstring. string comment; @@ -529,7 +553,7 @@ string GetPythonOp(const OpDef& op_def, bool is_hidden, const string& op_name) { } } - strings::Appendf(&result, " r\"\"\"%s\n Args:\n", comment.c_str()); + strings::StrAppend(&result, " r\"\"\"", comment, "\n Args:\n"); // Inputs for (int i = 0; i < op_def.input_arg_size(); ++i) { @@ -545,7 +569,7 @@ string GetPythonOp(const OpDef& op_def, bool is_hidden, const string& op_name) { if (!description.empty()) { AppendWithinWidth(&desc, description, kRightMargin - 4 /* indent */); } - strings::Appendf(&result, "%s", Indent(4, 6, desc).c_str()); + strings::StrAppend(&result, Indent(4, 6, desc)); } // Attrs @@ -567,6 +591,10 @@ string GetPythonOp(const OpDef& op_def, bool is_hidden, const string& op_name) { {"shape", "`tf.TensorShape` or list of `ints`"}, {"list(shape)", "list of shapes (each a `tf.TensorShape` or list of `ints`)"}, + {"tensor", "`tf.TensorProto`"}, + {"list(tensor)", "list of `tf.TensorProto` objects"}, + {"func", "function decorated with @Defun"}, + {"list(func)", "list of functions decorated with @Defun"}, }; for (size_t i = 0; i < TF_ARRAYSIZE(kAttrTypeName); ++i) { if (attr.type() == kAttrTypeName[i][0]) { @@ -610,14 +638,15 @@ string GetPythonOp(const OpDef& op_def, bool is_hidden, const string& op_name) { AppendWithinWidth(&desc, attr.description(), kRightMargin - 4 /* indent */); } - strings::Appendf(&result, "%s", Indent(4, 6, desc).c_str()); + strings::StrAppend(&result, Indent(4, 6, desc)); } - strings::Appendf(&result, " name: A name for the operation (optional).\n"); + strings::StrAppend(&result, + " name: A name for the operation (optional).\n"); std::vector output_type_string; - output_type_string.reserve(op_def.output_arg_size()); - for (int i = 0; i < op_def.output_arg_size(); ++i) { + output_type_string.reserve(num_outs); + for (int i = 0; i < num_outs; ++i) { output_type_string.push_back( ArgTypeName(op_def, op_def.output_arg(i), inferred_attrs, true)); } @@ -630,19 +659,18 @@ string GetPythonOp(const OpDef& op_def, bool is_hidden, const string& op_name) { } strings::StrAppend(&return_args, "name=name)"); - strings::Appendf(&result, " \"\"\"\n%s\n", - // Wrap the arguments, and indent to the (. - WordWrap(return_prefix, return_args, kRightMargin).c_str()); + strings::StrAppend(&result, " \"\"\"\n", + // Wrap the arguments, and indent to the (. + WordWrap(return_prefix, return_args, kRightMargin), "\n"); if (num_outs <= 1) { - strings::Appendf(&result, " return result\n"); + strings::StrAppend(&result, " return result\n"); } else { - string return_tuple = - strings::StrCat(" return _", op_def.name(), "Output._make(result)\n"); - strings::Appendf(&result, "%s", return_tuple.c_str()); + strings::StrAppend(&result, " return _", op_def.name(), + "Output._make(result)\n"); } + strings::StrAppend(&result, "\n\n"); - strings::Appendf(&result, "\n\n"); return result; } @@ -651,7 +679,7 @@ string GetPythonOps(const OpList& ops, const std::vector& hidden_ops, string result; // Header // TODO(josh11b): Mention the library for which wrappers are being generated. - strings::Appendf(&result, R"("""Python wrappers around Brain. + strings::StrAppend(&result, R"("""Python wrappers around TensorFlow ops. This file is MACHINE GENERATED! Do not edit. """ @@ -699,8 +727,8 @@ from tensorflow.python.framework import op_def_library as _op_def_library GetPythonOp(op_def, is_hidden, lower_case_name)); if (!require_shapes) { - strings::Appendf(&result, "_ops.RegisterShape(\"%s\")(None)\n", - op_def.name().c_str()); + strings::StrAppend(&result, "_ops.RegisterShape(\"", op_def.name(), + "\")(None)\n"); } auto added = out->Add(); @@ -722,7 +750,7 @@ _InitOpDefLibrary.op_list_ascii = """%s""" _op_def_lib = _InitOpDefLibrary() )", - cleaned_ops.DebugString().c_str()); + ProtoDebugString(cleaned_ops).c_str()); return result; } -- GitLab From e7a0308e61121338b5a86998b610abbaa2d8c372 Mon Sep 17 00:00:00 2001 From: Eugene Brevdo Date: Mon, 15 May 2017 18:19:43 -0700 Subject: [PATCH 658/697] [TF RNN] Fix recently introduced bug if the state size is a scalar. Turns out tf.concat((int32_tensor, []), 0) thinks the empty list [] is floating point. o_O. PiperOrigin-RevId: 156131674 --- tensorflow/python/ops/rnn_cell_impl.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/tensorflow/python/ops/rnn_cell_impl.py b/tensorflow/python/ops/rnn_cell_impl.py index cc2d9d037c..10d23eb09f 100644 --- a/tensorflow/python/ops/rnn_cell_impl.py +++ b/tensorflow/python/ops/rnn_cell_impl.py @@ -24,6 +24,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util @@ -66,8 +68,9 @@ def _concat(prefix, suffix, static=False): "but saw tensor: %s" % p) else: p = tensor_shape.as_shape(prefix) - p = p.as_list() if p.ndims is not None else None - p_static = p + p_static = p.as_list() if p.ndims is not None else None + p = (constant_op.constant(p.as_list(), dtype=dtypes.int32) + if p.is_fully_defined() else None) if isinstance(suffix, ops.Tensor): s = suffix s_static = tensor_util.constant_value(suffix) @@ -78,8 +81,9 @@ def _concat(prefix, suffix, static=False): "but saw tensor: %s" % s) else: s = tensor_shape.as_shape(suffix) - s = s.as_list() if s.ndims is not None else None - s_static = s + s_static = s.as_list() if s.ndims is not None else None + s = (constant_op.constant(s.as_list(), dtype=dtypes.int32) + if s.is_fully_defined() else None) if static: shape = tensor_shape.as_shape(p_static).concatenate(s_static) -- GitLab From 8962a1729a3662dd6d4a38935aaad61db79aa513 Mon Sep 17 00:00:00 2001 From: Andrew Harp Date: Mon, 15 May 2017 18:31:15 -0700 Subject: [PATCH 659/697] Android demo: remove unecessary loadlibrary call from TensorFlowImageClassifier PiperOrigin-RevId: 156132465 --- .../src/org/tensorflow/demo/TensorFlowImageClassifier.java | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowImageClassifier.java b/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowImageClassifier.java index f660178ebe..5756bd6b64 100644 --- a/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowImageClassifier.java +++ b/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowImageClassifier.java @@ -32,10 +32,6 @@ import org.tensorflow.contrib.android.TensorFlowInferenceInterface; /** A classifier specialized to label images using TensorFlow. */ public class TensorFlowImageClassifier implements Classifier { - static { - System.loadLibrary("tensorflow_demo"); - } - private static final String TAG = "TensorFlowImageClassifier"; // Only return this many results with at least this confidence. -- GitLab From 31c53c836232ad6d3caaa5bc7d42604667b127af Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 15 May 2017 19:13:11 -0700 Subject: [PATCH 660/697] Guarded whether the graph dashboard should request health pills on whether the feature is enabled, whether the user has toggled the feature on, and whether there is a dataset to request health pills for. PiperOrigin-RevId: 156135067 --- .../tf-graph-dashboard.html | 32 ++++++++----------- .../tf-graph-dashboard.html | 32 ++++++++----------- 2 files changed, 28 insertions(+), 36 deletions(-) diff --git a/tensorflow/tensorboard/components/tf_graph_dashboard/tf-graph-dashboard.html b/tensorflow/tensorboard/components/tf_graph_dashboard/tf-graph-dashboard.html index bfc52a0a44..891905e7c4 100644 --- a/tensorflow/tensorboard/components/tf_graph_dashboard/tf-graph-dashboard.html +++ b/tensorflow/tensorboard/components/tf_graph_dashboard/tf-graph-dashboard.html @@ -155,7 +155,7 @@ TF.Dashboard.TfGraphDashboard = Polymer({ 'node-toggle-expand': '_handleNodeToggleExpand', }, observers: [ - '_maybeFetchHealthPillsAtSpecificStep(allStepsModeEnabled, specificHealthPillStep)', + '_maybeFetchHealthPills(allStepsModeEnabled, specificHealthPillStep)', '_maybeInitializeDashboard(backend, _isAttached)', ], attached: function() { @@ -165,19 +165,16 @@ TF.Dashboard.TfGraphDashboard = Polymer({ this.set('_isAttached', false); }, reload: function() { - if (!this.debuggerDataEnabled || - !this.healthPillsToggledOn || - !this._renderHierarchy || - this._datasetsEmpty(this._datasets)) { - // Do not load debugger data if the feature is disabled, if the user toggled off the feature, - // or if the graph itself has not loaded yet. We need the graph to load so that we know which - // nodes to request health pills for. - return; - } - - // Request debugger data on graph reloads, but do not re-request the graph itself. The graph - // would not change across reloads. - this._requestHealthPills(); + this._maybeFetchHealthPills(); + }, + _shouldRequestHealthPills: function() { + // Do not load debugger data if the feature is disabled, if the user toggled off the feature, + // or if the graph itself has not loaded yet. We need the graph to load so that we know which + // nodes to request health pills for. + return this.debuggerDataEnabled && + this.healthPillsToggledOn && + this._renderHierarchy && + !this._datasetsEmpty(this._datasets); }, _maybeInitializeDashboard: function(backend, isAttached) { if (this._initialized || !backend || !isAttached) { @@ -279,7 +276,7 @@ TF.Dashboard.TfGraphDashboard = Polymer({ }, _handleNodeToggleExpand: function() { // Nodes were toggled. We may need to request health pills for more nodes. - this._requestHealthPills(); + this._maybeFetchHealthPills(); }, _healthPillsToggledOnChanged: function(healthPillsToggledOn) { if (healthPillsToggledOn) { @@ -291,9 +288,8 @@ TF.Dashboard.TfGraphDashboard = Polymer({ } }, // Fetch health pills for a specific step if applicable. - _maybeFetchHealthPillsAtSpecificStep: function(allStepsModeEnabled, specificHealthPillStep) { - if (!this._renderHierarchy) { - // The graph is not ready yet. + _maybeFetchHealthPills: function() { + if (!this._shouldRequestHealthPills()) { return; } diff --git a/tensorflow/tensorboard/components/tf_graph_dashboard_d3v4/tf-graph-dashboard.html b/tensorflow/tensorboard/components/tf_graph_dashboard_d3v4/tf-graph-dashboard.html index bfc52a0a44..891905e7c4 100644 --- a/tensorflow/tensorboard/components/tf_graph_dashboard_d3v4/tf-graph-dashboard.html +++ b/tensorflow/tensorboard/components/tf_graph_dashboard_d3v4/tf-graph-dashboard.html @@ -155,7 +155,7 @@ TF.Dashboard.TfGraphDashboard = Polymer({ 'node-toggle-expand': '_handleNodeToggleExpand', }, observers: [ - '_maybeFetchHealthPillsAtSpecificStep(allStepsModeEnabled, specificHealthPillStep)', + '_maybeFetchHealthPills(allStepsModeEnabled, specificHealthPillStep)', '_maybeInitializeDashboard(backend, _isAttached)', ], attached: function() { @@ -165,19 +165,16 @@ TF.Dashboard.TfGraphDashboard = Polymer({ this.set('_isAttached', false); }, reload: function() { - if (!this.debuggerDataEnabled || - !this.healthPillsToggledOn || - !this._renderHierarchy || - this._datasetsEmpty(this._datasets)) { - // Do not load debugger data if the feature is disabled, if the user toggled off the feature, - // or if the graph itself has not loaded yet. We need the graph to load so that we know which - // nodes to request health pills for. - return; - } - - // Request debugger data on graph reloads, but do not re-request the graph itself. The graph - // would not change across reloads. - this._requestHealthPills(); + this._maybeFetchHealthPills(); + }, + _shouldRequestHealthPills: function() { + // Do not load debugger data if the feature is disabled, if the user toggled off the feature, + // or if the graph itself has not loaded yet. We need the graph to load so that we know which + // nodes to request health pills for. + return this.debuggerDataEnabled && + this.healthPillsToggledOn && + this._renderHierarchy && + !this._datasetsEmpty(this._datasets); }, _maybeInitializeDashboard: function(backend, isAttached) { if (this._initialized || !backend || !isAttached) { @@ -279,7 +276,7 @@ TF.Dashboard.TfGraphDashboard = Polymer({ }, _handleNodeToggleExpand: function() { // Nodes were toggled. We may need to request health pills for more nodes. - this._requestHealthPills(); + this._maybeFetchHealthPills(); }, _healthPillsToggledOnChanged: function(healthPillsToggledOn) { if (healthPillsToggledOn) { @@ -291,9 +288,8 @@ TF.Dashboard.TfGraphDashboard = Polymer({ } }, // Fetch health pills for a specific step if applicable. - _maybeFetchHealthPillsAtSpecificStep: function(allStepsModeEnabled, specificHealthPillStep) { - if (!this._renderHierarchy) { - // The graph is not ready yet. + _maybeFetchHealthPills: function() { + if (!this._shouldRequestHealthPills()) { return; } -- GitLab From a26609596b2796bb3223c74587ce4d0e0f919ea9 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 15 May 2017 20:37:25 -0700 Subject: [PATCH 661/697] PiperOrigin-RevId: 156139276 --- tensorflow/contrib/layers/BUILD | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow/contrib/layers/BUILD b/tensorflow/contrib/layers/BUILD index fe661a5625..03af377149 100644 --- a/tensorflow/contrib/layers/BUILD +++ b/tensorflow/contrib/layers/BUILD @@ -304,6 +304,7 @@ py_test( py_test( name = "embedding_ops_test", size = "small", + timeout = "moderate", srcs = ["python/layers/embedding_ops_test.py"], srcs_version = "PY2AND3", deps = [ -- GitLab From 014a4c78af79568a0480f40a42f9869fd13fa9ee Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 16 May 2017 01:04:59 -0700 Subject: [PATCH 662/697] Add example to dynamic_partition and dynamic_stitch, that shows how they interact. PiperOrigin-RevId: 156153465 --- tensorflow/core/ops/data_flow_ops.cc | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/tensorflow/core/ops/data_flow_ops.cc b/tensorflow/core/ops/data_flow_ops.cc index 8bdbc7e135..c80ff983cf 100644 --- a/tensorflow/core/ops/data_flow_ops.cc +++ b/tensorflow/core/ops/data_flow_ops.cc @@ -101,6 +101,8 @@ For example: outputs[1] = [30, 40] ``` +See `dynamic_stitch` for an example on how to merge partitions back. +
@@ -189,6 +191,24 @@ For example: [51, 52], [61, 62]] ``` +This method can be used to merge partitions created by `dynamic_partition` +as illustrated on the following example: + +```python + # Apply function (increments x_i) on elements for which a certain condition + # apply (x_i != -1 in this example). + x=tf.constant([0.1, -1., 5.2, 4.3, -1., 7.4]) + condition_mask=tf.not_equal(x,tf.constant(-1.)) + partitioned_data = tf.dynamic_partition( + x, tf.cast(condition_mask, tf.int32) , 2) + partitioned_data[1] = partitioned_data[1] + 1.0 + condition_indices = tf.dynamic_partition( + tf.range(tf.shape(x)[0]), tf.cast(condition_mask, tf.int32) , 2) + x = tf.dynamic_stitch(condition_indices, partitioned_data) + # Here x=[1.1, -1., 6.2, 5.3, -1, 8.4], the -1. values remain + # unchanged. +``` +
-- GitLab From d72e92967f089f34599d98308cdb6a67d4c2db04 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 16 May 2017 01:07:41 -0700 Subject: [PATCH 663/697] Go: Update generated wrapper functions for TensorFlow ops. PiperOrigin-RevId: 156153612 --- tensorflow/go/op/wrappers.go | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index bc915789da..9c67c6cd4a 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -4967,6 +4967,24 @@ func TensorArrayGatherV2(scope *Scope, handle tf.Output, indices tf.Output, flow // [51, 52], [61, 62]] // ``` // +// This method can be used to merge partitions created by `dynamic_partition` +// as illustrated on the following example: +// +// ```python +// # Apply function (increments x_i) on elements for which a certain condition +// # apply (x_i != -1 in this example). +// x=tf.constant([0.1, -1., 5.2, 4.3, -1., 7.4]) +// condition_mask=tf.not_equal(x,tf.constant(-1.)) +// partitioned_data = tf.dynamic_partition( +// x, tf.cast(condition_mask, tf.int32) , 2) +// partitioned_data[1] = partitioned_data[1] + 1.0 +// condition_indices = tf.dynamic_partition( +// tf.range(tf.shape(x)[0]), tf.cast(condition_mask, tf.int32) , 2) +// x = tf.dynamic_stitch(condition_indices, partitioned_data) +// # Here x=[1.1, -1., 6.2, 5.3, -1, 8.4], the -1. values remain +// # unchanged. +// ``` +// //
// //
@@ -21386,6 +21404,8 @@ func Sum(scope *Scope, input tf.Output, reduction_indices tf.Output, optional .. // outputs[1] = [30, 40] // ``` // +// See `dynamic_stitch` for an example on how to merge partitions back. +// //
// //
-- GitLab From f1ce39d42b7f2429bfdedccad4dc5276fef03460 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 16 May 2017 01:48:43 -0700 Subject: [PATCH 664/697] Update ops-related pbtxt files. PiperOrigin-RevId: 156156322 --- tensorflow/core/ops/ops.pbtxt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt index 8079555128..b0304f2486 100644 --- a/tensorflow/core/ops/ops.pbtxt +++ b/tensorflow/core/ops/ops.pbtxt @@ -6538,7 +6538,7 @@ op { type: "type" } summary: "Partitions `data` into `num_partitions` tensors using indices from `partitions`." - description: "For each index tuple `js` of size `partitions.ndim`, the slice `data[js, ...]`\nbecomes part of `outputs[partitions[js]]`. The slices with `partitions[js] = i`\nare placed in `outputs[i]` in lexicographic order of `js`, and the first\ndimension of `outputs[i]` is the number of entries in `partitions` equal to `i`.\nIn detail,\n\n```python\n outputs[i].shape = [sum(partitions == i)] + data.shape[partitions.ndim:]\n\n outputs[i] = pack([data[js, ...] for js if partitions[js] == i])\n```\n\n`data.shape` must start with `partitions.shape`.\n\nFor example:\n\n```python\n # Scalar partitions.\n partitions = 1\n num_partitions = 2\n data = [10, 20]\n outputs[0] = [] # Empty with shape [0, 2]\n outputs[1] = [[10, 20]]\n\n # Vector partitions.\n partitions = [0, 0, 1, 1, 0]\n num_partitions = 2\n data = [10, 20, 30, 40, 50]\n outputs[0] = [10, 20, 50]\n outputs[1] = [30, 40]\n```\n\n
\n\n
" + description: "For each index tuple `js` of size `partitions.ndim`, the slice `data[js, ...]`\nbecomes part of `outputs[partitions[js]]`. The slices with `partitions[js] = i`\nare placed in `outputs[i]` in lexicographic order of `js`, and the first\ndimension of `outputs[i]` is the number of entries in `partitions` equal to `i`.\nIn detail,\n\n```python\n outputs[i].shape = [sum(partitions == i)] + data.shape[partitions.ndim:]\n\n outputs[i] = pack([data[js, ...] for js if partitions[js] == i])\n```\n\n`data.shape` must start with `partitions.shape`.\n\nFor example:\n\n```python\n # Scalar partitions.\n partitions = 1\n num_partitions = 2\n data = [10, 20]\n outputs[0] = [] # Empty with shape [0, 2]\n outputs[1] = [[10, 20]]\n\n # Vector partitions.\n partitions = [0, 0, 1, 1, 0]\n num_partitions = 2\n data = [10, 20, 30, 40, 50]\n outputs[0] = [10, 20, 50]\n outputs[1] = [30, 40]\n```\n\nSee `dynamic_stitch` for an example on how to merge partitions back.\n\n
\n\n
" } op { name: "DynamicStitch" @@ -6567,7 +6567,7 @@ op { type: "type" } summary: "Interleave the values from the `data` tensors into a single tensor." - description: "Builds a merged tensor such that\n\n```python\n merged[indices[m][i, ..., j], ...] = data[m][i, ..., j, ...]\n```\n\nFor example, if each `indices[m]` is scalar or vector, we have\n\n```python\n # Scalar indices:\n merged[indices[m], ...] = data[m][...]\n\n # Vector indices:\n merged[indices[m][i], ...] = data[m][i, ...]\n```\n\nEach `data[i].shape` must start with the corresponding `indices[i].shape`,\nand the rest of `data[i].shape` must be constant w.r.t. `i`. That is, we\nmust have `data[i].shape = indices[i].shape + constant`. In terms of this\n`constant`, the output shape is\n\n merged.shape = [max(indices)] + constant\n\nValues are merged in order, so if an index appears in both `indices[m][i]` and\n`indices[n][j]` for `(m,i) < (n,j)` the slice `data[n][j]` will appear in the\nmerged result.\n\nFor example:\n\n```python\n indices[0] = 6\n indices[1] = [4, 1]\n indices[2] = [[5, 2], [0, 3]]\n data[0] = [61, 62]\n data[1] = [[41, 42], [11, 12]]\n data[2] = [[[51, 52], [21, 22]], [[1, 2], [31, 32]]]\n merged = [[1, 2], [11, 12], [21, 22], [31, 32], [41, 42],\n [51, 52], [61, 62]]\n```\n\n
\n\n
" + description: "Builds a merged tensor such that\n\n```python\n merged[indices[m][i, ..., j], ...] = data[m][i, ..., j, ...]\n```\n\nFor example, if each `indices[m]` is scalar or vector, we have\n\n```python\n # Scalar indices:\n merged[indices[m], ...] = data[m][...]\n\n # Vector indices:\n merged[indices[m][i], ...] = data[m][i, ...]\n```\n\nEach `data[i].shape` must start with the corresponding `indices[i].shape`,\nand the rest of `data[i].shape` must be constant w.r.t. `i`. That is, we\nmust have `data[i].shape = indices[i].shape + constant`. In terms of this\n`constant`, the output shape is\n\n merged.shape = [max(indices)] + constant\n\nValues are merged in order, so if an index appears in both `indices[m][i]` and\n`indices[n][j]` for `(m,i) < (n,j)` the slice `data[n][j]` will appear in the\nmerged result.\n\nFor example:\n\n```python\n indices[0] = 6\n indices[1] = [4, 1]\n indices[2] = [[5, 2], [0, 3]]\n data[0] = [61, 62]\n data[1] = [[41, 42], [11, 12]]\n data[2] = [[[51, 52], [21, 22]], [[1, 2], [31, 32]]]\n merged = [[1, 2], [11, 12], [21, 22], [31, 32], [41, 42],\n [51, 52], [61, 62]]\n```\n\nThis method can be used to merge partitions created by `dynamic_partition`\nas illustrated on the following example:\n\n```python\n # Apply function (increments x_i) on elements for which a certain condition\n # apply (x_i != -1 in this example).\n x=tf.constant([0.1, -1., 5.2, 4.3, -1., 7.4])\n condition_mask=tf.not_equal(x,tf.constant(-1.))\n partitioned_data = tf.dynamic_partition(\n x, tf.cast(condition_mask, tf.int32) , 2)\n partitioned_data[1] = partitioned_data[1] + 1.0\n condition_indices = tf.dynamic_partition(\n tf.range(tf.shape(x)[0]), tf.cast(condition_mask, tf.int32) , 2)\n x = tf.dynamic_stitch(condition_indices, partitioned_data)\n # Here x=[1.1, -1., 6.2, 5.3, -1, 8.4], the -1. values remain\n # unchanged.\n```\n\n
\n\n
" } op { name: "EditDistance" -- GitLab From 7266ea5133be29660b4b7a360ee874c758fc1cd2 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 16 May 2017 03:19:28 -0700 Subject: [PATCH 665/697] Generalize layout assignment for bitcast reshapes. Currently, only special cases are detected where we can assign a layout so that the reshape is a bitcast. This CL generalizes this and if it is possible at all, a layout will be assigned so that the reshape is a bitcast. PiperOrigin-RevId: 156162657 --- tensorflow/compiler/xla/BUILD | 2 +- .../compiler/xla/service/layout_assignment.cc | 113 +++++++------- .../xla/service/layout_assignment_test.cc | 10 +- tensorflow/compiler/xla/shape_util.cc | 139 ++++++++++++++++++ tensorflow/compiler/xla/shape_util.h | 10 ++ tensorflow/compiler/xla/shape_util_test.cc | 54 +++++++ 6 files changed, 259 insertions(+), 69 deletions(-) diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index 65d4528421..de09d4b23f 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -216,7 +216,7 @@ cc_test( ":test_helpers", ":types", ":util", - "//tensorflow/core:test", + ":xla_data_proto", "//tensorflow/core:test_main", ], ) diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc index d413621cfe..bf279f1b2c 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/layout_assignment.cc @@ -659,44 +659,6 @@ LayoutAssignment::LayoutAssignment(ComputationLayout* entry_computation_layout) } } -namespace { - -// Given a pemutation of `{0, 1, ..., n}` `indices`, returns a permutation of -// `{0, 1, ..., n - to_delete.size() + to_insert.size()}` by deleting the -// indices `to_delete` wherever in `indices` they are, and inserting the indices -// `to_insert` arbitrarily at the back. -tensorflow::protobuf::RepeatedField -DeleteAndInsertIndices( - std::vector to_delete, std::vector to_insert, - tensorflow::protobuf::RepeatedField indices) { - std::sort(to_delete.begin(), to_delete.end(), std::greater()); - std::sort(to_insert.begin(), to_insert.end(), std::less()); - for (auto index : to_delete) { - auto i = indices.begin(); - while (i != indices.end()) { - if (*i == index) { - i = indices.erase(i); - } else { - if (*i > index) { - (*i)--; - } - ++i; - } - } - } - for (auto index : to_insert) { - for (auto i = indices.begin(); i != indices.end(); ++i) { - if (*i >= index) { - (*i)++; - } - } - indices.Add(index); - } - return indices; -} - -} // namespace - std::unique_ptr LayoutAssignment::ChooseOperandLayoutFromOutputLayout( const Layout& output_layout, const HloInstruction* instruction, int64 operand_no) { @@ -720,21 +682,32 @@ std::unique_ptr LayoutAssignment::ChooseOperandLayoutFromOutputLayout( } if (instruction->opcode() == HloOpcode::kReshape) { - // Pick the operand layout that makes the reshape a bitcast. If the reshape - // only inserts or deletes degenerate dimensions, we can easily compute the - // desired layout by accordingly inserting and deleting the elements in the - // minor-to-major list. - bool merely_inserts_or_deletes_1_sized_dims; - std::vector inserted_indices, deleted_indices; - std::tie(merely_inserts_or_deletes_1_sized_dims, deleted_indices, - inserted_indices) = - instruction->ReshapeMerelyInsertsOrDeletes1SizedDimensions(); - if (merely_inserts_or_deletes_1_sized_dims) { - Layout operand_layout = LayoutUtil::MakeLayout( - AsInt64Slice(DeleteAndInsertIndices(inserted_indices, deleted_indices, - output_layout.minor_to_major()))); + // Prefer the operand layout that makes the reshape an bitcast. If any + // dimension bound is 1 in the operand shape, there may be several such + // layouts. So if 'output_layout' is a MajorToMinor layout, try if the + // reshape is a bitcast when using the same layout. This may avoid copy + // operations. + const Shape& output_shape = instruction->shape(); + Shape output_shape_with_layout = ShapeUtil::MakeShapeWithLayout( + output_shape.element_type(), AsInt64Slice(output_shape.dimensions()), + AsInt64Slice(output_layout.minor_to_major())); + const Shape& operand_shape = operand->shape(); + if (LayoutUtil::IsMonotonicWithDim0Major(output_layout)) { + Shape operand_shape_with_layout = + ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout( + operand_shape.element_type(), + AsInt64Slice(operand_shape.dimensions())); + if (ShapeUtil::ReshapeIsBitcast(operand_shape_with_layout, + output_shape_with_layout)) { + return MakeUnique(operand_shape_with_layout.layout()); + } + } + auto aligned_operand_shape = + ShapeUtil::AlignLayouts(output_shape_with_layout, operand_shape); + if (aligned_operand_shape) { + auto operand_layout = aligned_operand_shape.value().layout(); TF_CHECK_OK( - LayoutUtil::ValidateLayoutForShape(operand_layout, operand->shape())); + LayoutUtil::ValidateLayoutForShape(operand_layout, operand_shape)); return MakeUnique(operand_layout); } } @@ -769,18 +742,32 @@ std::unique_ptr LayoutAssignment::ChooseOutputLayoutFromOperandLayout( } if (user->opcode() == HloOpcode::kReshape) { - // Pick the user layout that makes the reshape a bitcast. - bool merely_inserts_or_deletes_1_sized_dims; - std::vector inserted_indices, deleted_indices; - std::tie(merely_inserts_or_deletes_1_sized_dims, deleted_indices, - inserted_indices) = - user->ReshapeMerelyInsertsOrDeletes1SizedDimensions(); - if (merely_inserts_or_deletes_1_sized_dims) { - Layout user_layout = LayoutUtil::MakeLayout(AsInt64Slice( - DeleteAndInsertIndices(deleted_indices, inserted_indices, - operand_layout.minor_to_major()))); + // Prefer the user layout that makes the reshape an bitcast. If any + // dimension bound is 1 in the user shape, there may be several such + // layouts. So if 'operand_layout' is a MajorToMinor layout, try if the + // reshape is a bitcast when using the same layout. This may avoid copy + // operations. + Shape operand_shape_with_layout = ShapeUtil::MakeShapeWithLayout( + operand->shape().element_type(), + AsInt64Slice(operand->shape().dimensions()), + AsInt64Slice(operand_layout.minor_to_major())); + const Shape& output_shape = user->shape(); + if (LayoutUtil::IsMonotonicWithDim0Major(operand_layout)) { + Shape output_shape_with_layout = + ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout( + output_shape.element_type(), + AsInt64Slice(output_shape.dimensions())); + if (ShapeUtil::ReshapeIsBitcast(output_shape_with_layout, + operand_shape_with_layout)) { + return MakeUnique(output_shape_with_layout.layout()); + } + } + auto aligned_user_shape = + ShapeUtil::AlignLayouts(operand_shape_with_layout, output_shape); + if (aligned_user_shape) { + auto user_layout = aligned_user_shape.value().layout(); TF_CHECK_OK( - LayoutUtil::ValidateLayoutForShape(user_layout, user->shape())); + LayoutUtil::ValidateLayoutForShape(user_layout, output_shape)); return MakeUnique(user_layout); } } diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc index dd72566ac0..c6df9839c3 100644 --- a/tensorflow/compiler/xla/service/layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc @@ -319,7 +319,7 @@ TEST_F(LayoutAssignmentTest, ElementwiseAndReshape) { // param -> log -> reshape -> tanh auto builder = HloComputation::Builder(TestName()); Shape ashape = ShapeUtil::MakeShape(F32, {1, 2, 3, 1}); - Shape bshape = ShapeUtil::MakeShape(F32, {2, 1, 3}); + Shape bshape = ShapeUtil::MakeShape(F32, {3, 1, 2}); auto param = builder.AddInstruction( HloInstruction::CreateParameter(0, ashape, "param")); auto log = builder.AddInstruction( @@ -334,8 +334,8 @@ TEST_F(LayoutAssignmentTest, ElementwiseAndReshape) { Shape ashape_with_layout(ashape); Shape bshape_with_layout(bshape); - *ashape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({0, 1, 2, 3}); - *bshape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({0, 1, 2}); + *ashape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({0, 2, 1, 3}); + *bshape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({2, 1, 0}); ComputationLayout computation_layout(computation->ComputeProgramShape()); *computation_layout.mutable_parameter_layout(0) = @@ -345,12 +345,12 @@ TEST_F(LayoutAssignmentTest, ElementwiseAndReshape) { auto log_minor_to_major = AsInt64Slice(log->shape().layout().minor_to_major()); - EXPECT_LT(PositionInContainer(log_minor_to_major, 1), + EXPECT_GT(PositionInContainer(log_minor_to_major, 1), PositionInContainer(log_minor_to_major, 2)); auto reshape_minor_to_major = AsInt64Slice(reshape->shape().layout().minor_to_major()); - EXPECT_LT(PositionInContainer(reshape_minor_to_major, 0), + EXPECT_GT(PositionInContainer(reshape_minor_to_major, 0), PositionInContainer(reshape_minor_to_major, 2)); } diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index d3949918c8..2b32b78f0b 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" @@ -1022,6 +1023,144 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, check_input_unit_indices(output_shape, input_shape); } +/* static */ tensorflow::gtl::optional ShapeUtil::AlignLayouts( + const Shape& input_shape, const Shape& output_shape) { + int64 input_rank = ShapeUtil::Rank(input_shape); + int64 output_rank = ShapeUtil::Rank(output_shape); + + // First, calculate an alignment of the dimensions. A consecutive sequence of + // input dimensions and output dimensions belong to the same alignment part if + // the products of their dimension bounds are the same. In the easiest case, + // an alignment part consists of one input dimension and one output dimension + // which both have the same dimension bound. An alignment part specifies which + // dimensions need to be kept together in a physical layout if we want a + // reshape to be a bitcast. The order of the alignment parts is defined by the + // physical layout of the input shape, so when we construct the layout for the + // output shape we just process the alignment parts in this order, and then + // layout the dimensions belonging to each part in descending (major to minor) + // order. + + // Stores the input and output dimension numbers where each alignment part + // starts. + std::vector> alignment; + alignment.push_back({0, 0}); + + // Stores a mapping from the input dimension to the alignment part it belongs + // to. + std::vector dimension_to_alignment_index(input_rank); + int64 input_dimension_product = 1, output_dimension_product = 1; + for (int64 i = 0, j = 0; i < input_rank || j < output_rank;) { + // Check if we have reached the end of an alignment part. + if (input_dimension_product == output_dimension_product && + input_dimension_product > 1) { + alignment.push_back({i, j}); + input_dimension_product = output_dimension_product = 1; + } + if (input_dimension_product < output_dimension_product || + j == output_rank) { + if (i == input_rank) { + return tensorflow::gtl::nullopt; + } + dimension_to_alignment_index[i] = alignment.size() - 1; + input_dimension_product *= input_shape.dimensions(i); + ++i; + } else { + output_dimension_product *= output_shape.dimensions(j); + ++j; + } + } + if (input_dimension_product != output_dimension_product) { + return tensorflow::gtl::nullopt; + } + // We also need to store an end element so that we know where the last + // alignment part ends. + alignment.push_back({input_rank, output_rank}); + + // Now check if the physical layout can potentially be aligned to the output + // shape by changing the physical layout of the output shape. We need to check + // that all dimension numbers that belong to the same alignment part appear + // consecutively, and are in descending order. However we can ignore any + // trivial dimension bounds of 1, because they can be placed anywhere. + auto input_dimension_numbers = input_shape.layout().minor_to_major(); + std::vector output_layout; + output_layout.reserve(output_rank); + for (int64 i = 0; i < input_rank;) { + int64 current_dimension_number = input_dimension_numbers[i]; + + // Skip trivial dimensions with a bound of 1. + if (input_shape.dimensions(current_dimension_number) == 1) { + ++i; + continue; + } + + // Calculate the number of non-trivial dimension bounds in the input shape + // belonging to the current alignment part. + const int64 current_alignment_index = + dimension_to_alignment_index[current_dimension_number]; + // Because of the special end element that we added, we can be sure that + // 'current_alignment_index' is < alignment.size() - 1. + CHECK_LT(current_alignment_index, alignment.size() - 1); + int64 num_non_trivial_dimensions_in_alignment_part = 0; + for (int64 j = alignment[current_alignment_index].first; + j < alignment[current_alignment_index + 1].first; ++j) { + if (input_shape.dimensions(j) != 1) { + ++num_non_trivial_dimensions_in_alignment_part; + } + } + + // Check that the following 'num_non_trivial_dimensions_in_alignment_part' + // dimension numbers (ignoring dimension numbers with dimension bound 1) are + // in descending order and belong to the current alignment part. + for (int64 j = 0; j < num_non_trivial_dimensions_in_alignment_part; + ++i, ++j) { + if (i == input_rank) { + return tensorflow::gtl::nullopt; + } + // Skip trivial dimensions with a bound of 1. + if (input_shape.dimensions(input_dimension_numbers[i]) == 1) { + --j; + continue; + } + // If the current dimension number belongs to a different alignment part, + // or the dimension numbers are not in descending order, we can return + // early. + if (dimension_to_alignment_index[input_dimension_numbers[i]] != + current_alignment_index || + input_dimension_numbers[i] > current_dimension_number) { + return tensorflow::gtl::nullopt; + } + current_dimension_number = input_dimension_numbers[i]; + } + + // The output dimension numbers that belong to the current alignment part + // need to appear in the same descending order as in the input. Again, we + // can skip dimensions with a bound of 1. + for (int64 j = alignment[current_alignment_index + 1].second - 1; + j >= alignment[current_alignment_index].second; --j) { + if (output_shape.dimensions(j) != 1) { + output_layout.push_back(j); + } + } + } + // Now add all the dimensions with dimension bound 1 at the end of + // 'output_layout'. + for (int64 i = 0; i < output_rank; ++i) { + if (output_shape.dimensions(i) == 1) { + output_layout.push_back(i); + } + } + CHECK_EQ(output_layout.size(), output_rank); + std::vector dimension_sizes; + for (int64 i = 0; i < output_rank; ++i) { + dimension_sizes.push_back(output_shape.dimensions(i)); + } + Shape output_shape_with_layout = MakeShapeWithLayout( + output_shape.element_type(), AsInt64Slice(output_shape.dimensions()), + output_layout); + CHECK(ReshapeIsBitcast(input_shape, output_shape_with_layout)); + return output_shape_with_layout; +} + /* static */ Shape ShapeUtil::DeleteDimension(int64 dim_to_delete, Shape shape) { shape.mutable_dimensions()->erase(shape.dimensions().begin() + dim_to_delete); diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index 9b9dfccc57..aaf8e84cfe 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -378,6 +379,15 @@ class ShapeUtil { static bool ReshapeIsBitcast(const Shape& input_shape, const Shape& output_shape); + // Find a physical layout for 'output_shape' such that + // ShapeUtil::ReshapeIsBitcast(input_shape, output_shape_with_layout) returns + // true (where 'output_shape_with_layout' is 'output_shape' with the found + // layout). The layout of 'input_shape' is kept fixed. Returns + // 'output_shape_with_layout' if such a layout can be found, and an error + // otherwise. + static tensorflow::gtl::optional AlignLayouts( + const Shape& input_shape, const Shape& output_shape); + // Returns a shape with the given dimension deleted. // For example: // • `DeleteDimension(1, T[m, n, k]) = T[m, k]` diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc index 1f1f71d97c..73538b8b88 100644 --- a/tensorflow/compiler/xla/shape_util_test.cc +++ b/tensorflow/compiler/xla/shape_util_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { namespace { @@ -522,5 +523,58 @@ TEST(AlgebraicSimplifierTest, ReshapeIsBitcast_3x2x2_6x2_Dim0IsMostMinor) { ShapeUtil::MakeShapeWithLayout(F32, {6, 2}, {0, 1}))); } +TEST(AlignmentTest, AlignLayoutsWithoutTrivialDimensions) { + Shape input = ShapeUtil::MakeShapeWithLayout(xla::F32, {3, 8, 5, 7, 11}, + {3, 2, 1, 0, 4}); + auto aligned_shape = ShapeUtil::AlignLayouts( + input, ShapeUtil::MakeShape(xla::F32, {4, 3, 2, 7, 5, 11})); + EXPECT_TRUE(aligned_shape); + EXPECT_THAT(aligned_shape.value().layout().minor_to_major(), + ElementsAre(4, 3, 2, 1, 0, 5)); + EXPECT_TRUE(ShapeUtil::ReshapeIsBitcast(input, aligned_shape.value())); + + aligned_shape = ShapeUtil::AlignLayouts( + input, ShapeUtil::MakeShape(xla::F32, {3, 2, 4, 35, 11})); + EXPECT_TRUE(aligned_shape); + EXPECT_THAT(aligned_shape.value().layout().minor_to_major(), + ElementsAre(3, 2, 1, 0, 4)); + EXPECT_TRUE(ShapeUtil::ReshapeIsBitcast(input, aligned_shape.value())); +} + +TEST(AlignmentTest, AlignLayoutsWithTrivialDimensions) { + Shape input = + ShapeUtil::MakeShapeWithLayout(xla::F32, {1, 3, 8, 1, 5, 7, 1, 11, 1, 1}, + {5, 0, 4, 2, 1, 3, 6, 7, 9, 8}); + auto aligned_shape = ShapeUtil::AlignLayouts( + input, ShapeUtil::MakeShape(xla::F32, {1, 4, 1, 3, 2, 7, 5, 11, 1})); + EXPECT_TRUE(aligned_shape); + EXPECT_THAT(aligned_shape.value().layout().minor_to_major(), + ElementsAre(6, 5, 4, 3, 1, 7, 0, 2, 8)); + EXPECT_TRUE(ShapeUtil::ReshapeIsBitcast(input, aligned_shape.value())); +} + +// A test case where the consecutive elements of the input shape belonging to +// the same layout part are not in descending order. +TEST(AlignmentTest, AlignLayoutsWithoutTrivialDimensionsWrongInputLayout) { + // Same physical layout as in AlignLayoutsWithoutTrivialDimensions, except + // that the first two dimension numbers are exchanged. + Shape input = ShapeUtil::MakeShapeWithLayout(xla::F32, {3, 8, 5, 7, 11}, + {2, 3, 1, 0, 4}); + auto aligned_shape = ShapeUtil::AlignLayouts( + input, ShapeUtil::MakeShape(xla::F32, {4, 3, 2, 7, 5, 11})); + EXPECT_FALSE(aligned_shape); +} + +// A test case where the physical layout of the input shape does not place all +// dimensions that belong to the same alignment part consecutively. +TEST(AlignmentTest, + AlignLayoutsWithoutTrivialDimensionsNonConsecutiveAlignmentPart) { + Shape input = ShapeUtil::MakeShapeWithLayout(xla::F32, {3, 8, 5, 7, 11}, + {3, 2, 1, 0, 4}); + auto aligned_shape = ShapeUtil::AlignLayouts( + input, ShapeUtil::MakeShape(xla::F32, {4, 3, 2, 5, 77})); + EXPECT_FALSE(aligned_shape); +} + } // namespace } // namespace xla -- GitLab From 5cc5107cec12b8ec3afd2cb5f8a28777774f9339 Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Tue, 16 May 2017 05:33:19 -0700 Subject: [PATCH 666/697] Add build rule for doc generation PiperOrigin-RevId: 156170472 --- tensorflow/BUILD | 2 +- tensorflow/tools/docs/BUILD | 15 +++++++++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/tensorflow/BUILD b/tensorflow/BUILD index 503ad79a38..0c2fdab9b8 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -458,7 +458,7 @@ filegroup( filegroup( name = "docs_src", - data = glob(["docs_src/**/*.md"]), + srcs = glob(["docs_src/**/*.md"]), ) # ------------------------------------------- diff --git a/tensorflow/tools/docs/BUILD b/tensorflow/tools/docs/BUILD index 8e27b133c2..0c0239352b 100644 --- a/tensorflow/tools/docs/BUILD +++ b/tensorflow/tools/docs/BUILD @@ -151,3 +151,18 @@ filegroup( ], ), ) + +genrule( + name = "python_docs", + srcs = ["//tensorflow:docs_src"], + outs = ["python_docs.tgz"], + cmd = "STARTDIR=$$(pwd); " + + "TMP=$$(mktemp -d $${TMPDIR:-/tmp}/docs.XXXXXXXXXX); " + + "cd $$TMP; " + + "$$STARTDIR/$(location :generate) --src_dir=$$STARTDIR/third_party/tensorflow/docs_src --output_dir=docs_out; " + + "tar -czf $$STARTDIR/$@ docs_out; " + + "cd $$STARTDIR; " + + "rm -rf $$TMP; ", + tools = [":generate"], + visibility = ["//visibility:public"], +) -- GitLab From f7040ddf1ad5ebf6c871e987b038bfd5f343546e Mon Sep 17 00:00:00 2001 From: Jianwei Xie Date: Tue, 16 May 2017 08:11:43 -0700 Subject: [PATCH 667/697] Adds learn_runner to open source tf.contrib.learn symbols. PiperOrigin-RevId: 156183214 --- tensorflow/contrib/learn/__init__.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tensorflow/contrib/learn/__init__.py b/tensorflow/contrib/learn/__init__.py index 05c4024d0b..aec4911e29 100644 --- a/tensorflow/contrib/learn/__init__.py +++ b/tensorflow/contrib/learn/__init__.py @@ -88,9 +88,11 @@ from __future__ import print_function from tensorflow.contrib.learn.python.learn import * # pylint: enable=wildcard-import +from tensorflow.contrib.learn.python.learn import learn_runner + from tensorflow.python.util.all_util import remove_undocumented -_allowed_symbols = ['datasets', 'head', 'io', 'models', +_allowed_symbols = ['datasets', 'head', 'io', 'learn_runner', 'models', 'monitors', 'NotFittedError', 'ops', 'preprocessing', 'utils', 'graph_actions'] -- GitLab From 3038bc913713bc31d0b67150e7cf7c056baba7e4 Mon Sep 17 00:00:00 2001 From: Adria Puigdomenech Date: Tue, 16 May 2017 09:29:45 -0700 Subject: [PATCH 668/697] Change rnn ops to check types via duck-typing, instead of a private attribute. PiperOrigin-RevId: 156190878 --- tensorflow/contrib/rnn/python/ops/core_rnn.py | 8 ++++---- .../contrib/rnn/python/ops/core_rnn_cell_impl.py | 15 ++++++++++----- tensorflow/contrib/rnn/python/ops/rnn_cell.py | 3 ++- .../seq2seq/python/ops/attention_wrapper.py | 2 +- .../contrib/seq2seq/python/ops/basic_decoder.py | 4 ++-- .../seq2seq/python/ops/beam_search_decoder.py | 4 ++-- tensorflow/python/ops/rnn.py | 15 +++++---------- tensorflow/python/ops/rnn_cell_impl.py | 7 +++++++ 8 files changed, 33 insertions(+), 25 deletions(-) diff --git a/tensorflow/contrib/rnn/python/ops/core_rnn.py b/tensorflow/contrib/rnn/python/ops/core_rnn.py index bbfa6b8850..3ce075ce9c 100644 --- a/tensorflow/contrib/rnn/python/ops/core_rnn.py +++ b/tensorflow/contrib/rnn/python/ops/core_rnn.py @@ -19,7 +19,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.rnn.python.ops import core_rnn_cell from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops @@ -32,6 +31,7 @@ from tensorflow.python.util import nest # pylint: disable=protected-access _concat = rnn_cell_impl._concat +_like_rnncell = rnn_cell_impl._like_rnncell _infer_state_dtype = rnn._infer_state_dtype _reverse_seq = rnn._reverse_seq _rnn_step = rnn._rnn_step @@ -99,7 +99,7 @@ def static_rnn(cell, inputs, initial_state=None, dtype=None, (column size) cannot be inferred from inputs via shape inference. """ - if not isinstance(cell, core_rnn_cell.RNNCell): + if not _like_rnncell(cell): raise TypeError("cell must be an instance of RNNCell") if not nest.is_sequence(inputs): raise TypeError("inputs must be a sequence") @@ -319,9 +319,9 @@ def static_bidirectional_rnn(cell_fw, cell_bw, inputs, ValueError: If inputs is None or an empty list. """ - if not isinstance(cell_fw, core_rnn_cell.RNNCell): + if not _like_rnncell(cell_fw): raise TypeError("cell_fw must be an instance of RNNCell") - if not isinstance(cell_bw, core_rnn_cell.RNNCell): + if not _like_rnncell(cell_bw): raise TypeError("cell_bw must be an instance of RNNCell") if not nest.is_sequence(inputs): raise TypeError("inputs must be a sequence") diff --git a/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py b/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py index eba2c0d2ac..f3e57cd3ec 100644 --- a/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py +++ b/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py @@ -42,16 +42,21 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops import partitioned_variables from tensorflow.python.ops import random_ops +from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.ops import variable_scope as vs from tensorflow.python.ops.math_ops import sigmoid from tensorflow.python.ops.math_ops import tanh -from tensorflow.python.ops.rnn_cell_impl import _RNNCell as RNNCell from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import nest +# pylint: disable=protected-access +RNNCell = rnn_cell_impl._RNNCell # pylint: disable=invalid-name +_like_rnncell = rnn_cell_impl._like_rnncell +# pylint: enable=protected-access + _BIAS_VARIABLE_NAME = "biases" _WEIGHTS_VARIABLE_NAME = "weights" @@ -424,7 +429,7 @@ class OutputProjectionWrapper(RNNCell): ValueError: if output_size is not positive. """ super(OutputProjectionWrapper, self).__init__(_reuse=reuse) - if not isinstance(cell, RNNCell): + if not _like_rnncell(cell): raise TypeError("The parameter cell is not RNNCell.") if output_size < 1: raise ValueError("Parameter output_size must be > 0: %d." % output_size) @@ -480,7 +485,7 @@ class InputProjectionWrapper(RNNCell): super(InputProjectionWrapper, self).__init__(_reuse=reuse) if input_size is not None: logging.warn("%s: The input_size parameter is deprecated.", self) - if not isinstance(cell, RNNCell): + if not _like_rnncell(cell): raise TypeError("The parameter cell is not RNNCell.") self._cell = cell self._num_proj = num_proj @@ -556,7 +561,7 @@ class DropoutWrapper(RNNCell): TypeError: if cell is not an RNNCell. ValueError: if any of the keep_probs are not between 0 and 1. """ - if not isinstance(cell, RNNCell): + if not _like_rnncell(cell): raise TypeError("The parameter cell is not a RNNCell.") with ops.name_scope("DropoutWrapperInit"): def tensor_and_const_value(v): @@ -791,7 +796,7 @@ class EmbeddingWrapper(RNNCell): ValueError: if embedding_classes is not positive. """ super(EmbeddingWrapper, self).__init__(_reuse=reuse) - if not isinstance(cell, RNNCell): + if not _like_rnncell(cell): raise TypeError("The parameter cell is not RNNCell.") if embedding_classes <= 0 or embedding_size <= 0: raise ValueError("Both embedding_classes and embedding_size must be > 0: " diff --git a/tensorflow/contrib/rnn/python/ops/rnn_cell.py b/tensorflow/contrib/rnn/python/ops/rnn_cell.py index 7a0f894404..217c379c36 100644 --- a/tensorflow/contrib/rnn/python/ops/rnn_cell.py +++ b/tensorflow/contrib/rnn/python/ops/rnn_cell.py @@ -34,6 +34,7 @@ from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops import random_ops +from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.ops import variable_scope as vs from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import nest @@ -1057,7 +1058,7 @@ class AttentionCellWrapper(core_rnn_cell.RNNCell): `state_is_tuple` is `False` or if attn_length is zero or less. """ super(AttentionCellWrapper, self).__init__(_reuse=reuse) - if not isinstance(cell, core_rnn_cell.RNNCell): + if not rnn_cell_impl._like_rnncell(cell): # pylint: disable=protected-access raise TypeError("The parameter cell is not RNNCell.") if nest.is_sequence(cell.state_size) and not state_is_tuple: raise ValueError("Cell returns tuple of states, but the flag " diff --git a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py index 686a85e4e7..fd76882d84 100644 --- a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py +++ b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py @@ -543,7 +543,7 @@ class AttentionWrapper(core_rnn_cell.RNNCell): name: Name to use when creating ops. """ super(AttentionWrapper, self).__init__(name=name) - if not isinstance(cell, core_rnn_cell.RNNCell): + if not rnn_cell_impl._like_rnncell(cell): # pylint: disable=protected-access raise TypeError( "cell must be an RNNCell, saw type: %s" % type(cell).__name__) if not isinstance(attention_mechanism, AttentionMechanism): diff --git a/tensorflow/contrib/seq2seq/python/ops/basic_decoder.py b/tensorflow/contrib/seq2seq/python/ops/basic_decoder.py index 6231a1fdf9..8ae175b6b5 100644 --- a/tensorflow/contrib/seq2seq/python/ops/basic_decoder.py +++ b/tensorflow/contrib/seq2seq/python/ops/basic_decoder.py @@ -21,13 +21,13 @@ from __future__ import print_function import collections -from tensorflow.contrib.rnn import core_rnn_cell from tensorflow.contrib.seq2seq.python.ops import decoder from tensorflow.contrib.seq2seq.python.ops import helper as helper_py from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.layers import base as layers_base +from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.util import nest @@ -60,7 +60,7 @@ class BasicDecoder(decoder.Decoder): Raises: TypeError: if `cell`, `helper` or `output_layer` have an incorrect type. """ - if not isinstance(cell, core_rnn_cell.RNNCell): + if not rnn_cell_impl._like_rnncell(cell): # pylint: disable=protected-access raise TypeError("cell must be an RNNCell, received: %s" % type(cell)) if not isinstance(helper, helper_py.Helper): raise TypeError("helper must be a Helper, received: %s" % type(helper)) diff --git a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py index 63ce9dafc0..eb494bda4b 100644 --- a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py +++ b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py @@ -20,7 +20,6 @@ from __future__ import print_function import collections -from tensorflow.contrib.rnn import core_rnn_cell from tensorflow.contrib.seq2seq.python.ops import beam_search_ops from tensorflow.contrib.seq2seq.python.ops import decoder from tensorflow.python.framework import dtypes @@ -33,6 +32,7 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import embedding_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.ops import tensor_array_ops from tensorflow.python.util import nest @@ -143,7 +143,7 @@ class BeamSearchDecoder(decoder.Decoder): ValueError: If `start_tokens` is not a vector or `end_token` is not a scalar. """ - if not isinstance(cell, core_rnn_cell.RNNCell): + if not rnn_cell_impl._like_rnncell(cell): # pylint: disable=protected-access raise TypeError("cell must be an RNNCell, received: %s" % type(cell)) if (output_layer is not None and not isinstance(output_layer, layers_base.Layer)): diff --git a/tensorflow/python/ops/rnn.py b/tensorflow/python/ops/rnn.py index 475c49091e..2aa288e36a 100644 --- a/tensorflow/python/ops/rnn.py +++ b/tensorflow/python/ops/rnn.py @@ -34,6 +34,7 @@ from tensorflow.python.util import nest # pylint: disable=protected-access _concat = rnn_cell_impl._concat +_like_rnncell = rnn_cell_impl._like_rnncell # pylint: enable=protected-access @@ -361,12 +362,10 @@ def bidirectional_dynamic_rnn(cell_fw, cell_bw, inputs, sequence_length=None, TypeError: If `cell_fw` or `cell_bw` is not an instance of `RNNCell`. """ - # pylint: disable=protected-access - if not isinstance(cell_fw, rnn_cell_impl._RNNCell): + if not _like_rnncell(cell_fw): raise TypeError("cell_fw must be an instance of RNNCell") - if not isinstance(cell_bw, rnn_cell_impl._RNNCell): + if not _like_rnncell(cell_bw): raise TypeError("cell_bw must be an instance of RNNCell") - # pylint: enable=protected-access with vs.variable_scope(scope or "bidirectional_rnn"): # Forward direction @@ -507,10 +506,8 @@ def dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None, ValueError: If inputs is None or an empty list. """ - # pylint: disable=protected-access - if not isinstance(cell, rnn_cell_impl._RNNCell): + if not _like_rnncell(cell): raise TypeError("cell must be an instance of RNNCell") - # pylint: enable=protected-access # By default, time_major==False and inputs are batch-major: shaped # [batch, time, depth] @@ -921,10 +918,8 @@ def raw_rnn(cell, loop_fn, a `callable`. """ - # pylint: disable=protected-access - if not isinstance(cell, rnn_cell_impl._RNNCell): + if not _like_rnncell(cell): raise TypeError("cell must be an instance of RNNCell") - # pylint: enable=protected-access if not callable(loop_fn): raise TypeError("loop_fn must be a callable") diff --git a/tensorflow/python/ops/rnn_cell_impl.py b/tensorflow/python/ops/rnn_cell_impl.py index 10d23eb09f..9c0fb1db23 100644 --- a/tensorflow/python/ops/rnn_cell_impl.py +++ b/tensorflow/python/ops/rnn_cell_impl.py @@ -36,6 +36,13 @@ from tensorflow.python.ops import variables as tf_variables from tensorflow.python.util import nest +def _like_rnncell(cell): + """Checks that a given object is an RNNCell by using duck typing.""" + conditions = [hasattr(cell, "output_size"), hasattr(cell, "state_size"), + hasattr(cell, "zero_state"), callable(cell)] + return all(conditions) + + def _concat(prefix, suffix, static=False): """Concat that enables int, Tensor, or TensorShape values. -- GitLab From c30ae04407940be85678cafb25747e2412e74db9 Mon Sep 17 00:00:00 2001 From: Eugene Brevdo Date: Tue, 16 May 2017 09:33:17 -0700 Subject: [PATCH 669/697] Bugfix: tf.reverse supports string type. PiperOrigin-RevId: 156191186 --- tensorflow/core/kernels/reverse_op.cc | 1 + tensorflow/core/ops/array_ops.cc | 4 ++-- tensorflow/python/kernel_tests/array_ops_test.py | 8 ++++---- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/tensorflow/core/kernels/reverse_op.cc b/tensorflow/core/kernels/reverse_op.cc index 24b3ba31b8..6f7a0a4df5 100644 --- a/tensorflow/core/kernels/reverse_op.cc +++ b/tensorflow/core/kernels/reverse_op.cc @@ -266,6 +266,7 @@ class ReverseV2Op : public OpKernel { .HostMemory("axis"), \ ReverseV2Op) TF_CALL_POD_TYPES(REGISTER_KERNELS); +TF_CALL_string(REGISTER_KERNELS); #undef REGISTER_KERNELS #if GOOGLE_CUDA diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc index 3d52a29de5..b9e56a1742 100644 --- a/tensorflow/core/ops/array_ops.cc +++ b/tensorflow/core/ops/array_ops.cc @@ -1004,7 +1004,7 @@ REGISTER_OP("Reverse") .Output("output: T") .Attr( "T: {uint8, int8, int32, int64, bool, half, float, double, complex64, " - "complex128}") + "complex128, string}") .SetShapeFn([](InferenceContext* c) { ShapeHandle input = c->input(0); ShapeHandle dims; @@ -1081,7 +1081,7 @@ REGISTER_OP("ReverseV2") .Attr("Tidx: {int32, int64} = DT_INT32") .Attr( "T: {uint8, int8, int32, int64, bool, half, float, double, complex64, " - "complex128}") + "complex128, string}") .SetShapeFn([](InferenceContext* c) { ShapeHandle input = c->input(0); ShapeHandle axis; diff --git a/tensorflow/python/kernel_tests/array_ops_test.py b/tensorflow/python/kernel_tests/array_ops_test.py index 49695dd3ca..7b8cd25664 100644 --- a/tensorflow/python/kernel_tests/array_ops_test.py +++ b/tensorflow/python/kernel_tests/array_ops_test.py @@ -240,7 +240,7 @@ class ReverseV2Test(test_util.TensorFlowTestCase): self.assertAllEqual(x_tf, x_np) def _reverse1DimAuto(self, np_dtype): - x_np = np.array([1, 2, 3, 4, 5], dtype=np_dtype) + x_np = np.array([1, 200, 3, 40, 5], dtype=np_dtype) for use_gpu in [False, True]: with self.test_session(use_gpu=use_gpu): @@ -248,7 +248,7 @@ class ReverseV2Test(test_util.TensorFlowTestCase): self.assertAllEqual(x_tf, np.asarray(x_np)[::-1]) def _reverse2DimAuto(self, np_dtype): - x_np = np.array([[1, 2, 3], [4, 5, 6]], dtype=np_dtype) + x_np = np.array([[1, 200, 3], [4, 5, 60]], dtype=np_dtype) for reverse_f in [array_ops.reverse_v2, array_ops.reverse]: for use_gpu in [False, True]: @@ -283,14 +283,14 @@ class ReverseV2Test(test_util.TensorFlowTestCase): def testReverse1DimAuto(self): for dtype in [ np.uint8, np.int8, np.int32, np.int64, np.bool, np.float16, np.float32, - np.float64, np.complex64, np.complex128 + np.float64, np.complex64, np.complex128, np.array(b"").dtype.type ]: self._reverse1DimAuto(dtype) def testReverse2DimAuto(self): for dtype in [ np.uint8, np.int8, np.int32, np.int64, np.bool, np.float16, np.float32, - np.float64, np.complex64, np.complex128 + np.float64, np.complex64, np.complex128, np.array(b"").dtype.type ]: self._reverse2DimAuto(dtype) -- GitLab From 5e3545c24eeb178ce8231067f1c8c42c7c42750f Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 16 May 2017 09:50:45 -0700 Subject: [PATCH 670/697] [TF:XLA] Fix compilation bug for computations with one non-constant output and one or more variable outputs. PiperOrigin-RevId: 156193188 --- tensorflow/compiler/tests/variable_ops_test.py | 16 ++++++++++++++++ tensorflow/compiler/tf2xla/xla_compiler.cc | 4 ++-- 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/tensorflow/compiler/tests/variable_ops_test.py b/tensorflow/compiler/tests/variable_ops_test.py index dcb9e2db2f..fef390fd67 100644 --- a/tensorflow/compiler/tests/variable_ops_test.py +++ b/tensorflow/compiler/tests/variable_ops_test.py @@ -26,6 +26,7 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables @@ -36,6 +37,21 @@ from tensorflow.python.training.gradient_descent import GradientDescentOptimizer class VariableOpsTest(XLATestCase): """Test cases for resource variable operators.""" + def testOneWriteOneOutput(self): + # Regression test for a bug where computations with one non-constant + # output and one variable update were mishandled. + for dtype in self.numeric_types: + init = np.array([[1, 2], [3, 4]], dtype=dtype) + with self.test_session() as sess, self.test_scope(): + v = resource_variable_ops.ResourceVariable(init) + sess.run(variables.variables_initializer([v])) + p = array_ops.placeholder(dtype) + x = v.assign_add(p) + with ops.control_dependencies([x]): + y = v.read_value() + self.assertAllClose(np.array([[2, 3], [4, 5]], dtype=dtype), + sess.run(y, {p: 1})) + def testReadWrite(self): """Tests initialization, reading, and writing a resource variable.""" with self.test_session() as session: diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index d246e7f9ac..a8034a2ec6 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -491,10 +491,10 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, i < context->retvals().size(); ++i) { const XlaContext::HandleOrConstant& retval = context->retvals()[i]; if (!retval.is_constant) { - CHECK_LT(computation_output, num_nonconst_outputs); + CHECK_LT(computation_output, num_computation_outputs); OutputDescription& output = result->outputs[i]; output.is_constant = false; - if (num_nonconst_outputs > 1) { + if (num_computation_outputs > 1) { output.shape = XLAShapeToTensorShape(xla::ShapeUtil::GetTupleElementShape( result->xla_output_shape, computation_output)); -- GitLab From 8d758e1e1843202e273e88f35dcf3bc8f4803906 Mon Sep 17 00:00:00 2001 From: Leandro Gracia Gil Date: Wed, 17 May 2017 01:56:08 +0900 Subject: [PATCH 671/697] Make possible to use static libraries generated by tfcompile in MSVC. (#9908) Currently it is possible to generate MSVC-compatible static libraries from a Linux build of tfcompile using the target triple "x86_64-pc-windows-msvc". However, just linking such library does not work out-of-the-box because of a trivial class -> struct issue that becomes a compiler error in MSVC and because multiple other symbols are missing. The files defining these symbols must be manually built into the MSVC project, but many have minor incompatibilities that prevent them to be built. This patch makes a few minor modifications to these files, making it possible to link and run Tensorflow graphs generated with tfcompile in MSVC in Windows. The files that might need to be built into the MSVC project are as follows: 1. For symbols required by the generated headers. - tensorflow/compiler/aot/runtime.cc - tensorflow/compiler/xla/executable_run_options.cc 2. For symbols that provide custom XLA operations. - tensorflow/compiler/tf2xla/kernels/gather_op_kernel_float_int32.cc - tensorflow/compiler/tf2xla/kernels/gather_op_kernel_float_int64.cc - tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_1d.cc - tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_2d.cc 3. For symbols required by SimpleResolver in xla/service/cpu/simple_orc_jit.cc. - tensorflow/compiler/xla/service/cpu/runtime_matmul.cc - tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc - tensorflow/compiler/xla/service/cpu/runtime_conv2d.cc - tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.cc - tensorflow/compiler/xla/service/cpu/cpu_runtime.cc Additionally, the following files might also required by SimpleResolver, but are not included in this patch because there is no simple alternative to __attribute__((vector_size(x))) and __attribute__((weak)) in MSVC. - tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.cc - tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.cc See https://github.com/tensorflow/tensorflow/issues/8310 for more information. --- tensorflow/compiler/aot/codegen.cc | 2 +- tensorflow/compiler/aot/codegen_test_h.golden | 2 +- tensorflow/compiler/aot/runtime.cc | 10 +++++++++- .../tf2xla/kernels/gather_op_kernel_float_int32.cc | 3 ++- .../tf2xla/kernels/gather_op_kernel_float_int64.cc | 3 ++- .../tf2xla/kernels/index_ops_kernel_argmax_float_1d.cc | 3 ++- .../tf2xla/kernels/index_ops_kernel_argmax_float_2d.cc | 3 ++- tensorflow/compiler/tf2xla/xla_local_runtime_context.h | 2 +- tensorflow/compiler/xla/service/cpu/cpu_runtime.cc | 1 - tensorflow/compiler/xla/service/cpu/runtime_matmul.cc | 4 ++-- .../xla/service/cpu/runtime_single_threaded_matmul.cc | 4 ++-- tensorflow/compiler/xla/tests/custom_call_test.cc | 7 ++++--- 12 files changed, 28 insertions(+), 16 deletions(-) diff --git a/tensorflow/compiler/aot/codegen.cc b/tensorflow/compiler/aot/codegen.cc index a53e82d34b..bbdb342a62 100644 --- a/tensorflow/compiler/aot/codegen.cc +++ b/tensorflow/compiler/aot/codegen.cc @@ -365,7 +365,7 @@ Status GenerateHeader(const HeaderOpts& opts, const Config& config, #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" -namespace Eigen { class ThreadPoolDevice; } +namespace Eigen { struct ThreadPoolDevice; } // (Implementation detail) Entry point to the function in the object file. extern "C" void {{ENTRY}}( diff --git a/tensorflow/compiler/aot/codegen_test_h.golden b/tensorflow/compiler/aot/codegen_test_h.golden index 46d7c03006..01963c6df4 100644 --- a/tensorflow/compiler/aot/codegen_test_h.golden +++ b/tensorflow/compiler/aot/codegen_test_h.golden @@ -15,7 +15,7 @@ #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" -namespace Eigen { class ThreadPoolDevice; } +namespace Eigen { struct ThreadPoolDevice; } // (Implementation detail) Entry point to the function in the object file. extern "C" void entry_point( diff --git a/tensorflow/compiler/aot/runtime.cc b/tensorflow/compiler/aot/runtime.cc index 208de5498d..5772776666 100644 --- a/tensorflow/compiler/aot/runtime.cc +++ b/tensorflow/compiler/aot/runtime.cc @@ -31,6 +31,8 @@ namespace { inline void* aligned_malloc(size_t size, int minimum_alignment) { #if defined(__ANDROID__) || defined(OS_ANDROID) || defined(OS_CYGWIN) return memalign(minimum_alignment, size); +#elif defined(COMPILER_MSVC) + return _aligned_malloc(size, minimum_alignment); #else // !__ANDROID__ && !OS_ANDROID && !OS_CYGWIN void* ptr = nullptr; // posix_memalign requires that the requested alignment be at least @@ -45,7 +47,13 @@ inline void* aligned_malloc(size_t size, int minimum_alignment) { #endif } -inline void aligned_free(void* aligned_memory) { free(aligned_memory); } +inline void aligned_free(void* aligned_memory) { +#if defined(COMPILER_MSVC) + _aligned_free(aligned_memory); +#else + free(aligned_memory); +#endif +} size_t align_to(size_t n, size_t align) { return (((n - 1) / align) + 1) * align; diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op_kernel_float_int32.cc b/tensorflow/compiler/tf2xla/kernels/gather_op_kernel_float_int32.cc index eff23bd77d..ef844cc6c5 100644 --- a/tensorflow/compiler/tf2xla/kernels/gather_op_kernel_float_int32.cc +++ b/tensorflow/compiler/tf2xla/kernels/gather_op_kernel_float_int32.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/kernels/gather_functor.h" #include "tensorflow/core/platform/dynamic_annotations.h" +#include "tensorflow/core/platform/macros.h" namespace tensorflow { @@ -63,7 +64,7 @@ EIGEN_STRONG_INLINE void gather_float_int32_xla_impl(float* out, void** data) { // Implements gather on CPU. This is called by an XLA custom call, set up by // gather_op.cc. -extern "C" void __attribute__((visibility("default"))) +extern "C" void TF_EXPORT gather_float_int32_xla_impl(float* out, void** data) { tensorflow::gather_float_int32_xla_impl(out, data); } diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op_kernel_float_int64.cc b/tensorflow/compiler/tf2xla/kernels/gather_op_kernel_float_int64.cc index ae31f6f200..4c8693d197 100644 --- a/tensorflow/compiler/tf2xla/kernels/gather_op_kernel_float_int64.cc +++ b/tensorflow/compiler/tf2xla/kernels/gather_op_kernel_float_int64.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/kernels/gather_functor.h" #include "tensorflow/core/platform/dynamic_annotations.h" +#include "tensorflow/core/platform/macros.h" namespace tensorflow { @@ -63,7 +64,7 @@ EIGEN_STRONG_INLINE void gather_float_int64_xla_impl(float* out, void** data) { // Implements gather on CPU. This is called by an XLA custom call, set up by // gather_op.cc. -extern "C" void __attribute__((visibility("default"))) +extern "C" void TF_EXPORT gather_float_int64_xla_impl(float* out, void** data) { tensorflow::gather_float_int64_xla_impl(out, data); } diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_1d.cc b/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_1d.cc index 0033a949a3..a71f2fcf0f 100644 --- a/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_1d.cc +++ b/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_1d.cc @@ -18,6 +18,7 @@ limitations under the License. #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/platform/dynamic_annotations.h" +#include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" namespace tensorflow { @@ -43,7 +44,7 @@ EIGEN_STRONG_INLINE void argmax_float_1d_xla_impl(void* out, void** data) { // Implements argmax on CPU. This is called by an XLA custom call, set up by // index_ops.cc. -extern "C" void __attribute__((visibility("default"))) +extern "C" void TF_EXPORT argmax_float_1d_xla_impl(void* out, void** data) { tensorflow::argmax_float_1d_xla_impl(out, data); } diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_2d.cc b/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_2d.cc index be8ad2317c..f30eb6121f 100644 --- a/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_2d.cc +++ b/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_2d.cc @@ -18,6 +18,7 @@ limitations under the License. #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/platform/dynamic_annotations.h" +#include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" namespace tensorflow { @@ -45,7 +46,7 @@ EIGEN_STRONG_INLINE void argmax_float_2d_xla_impl(void* out, void** data) { // Implements argmax on CPU. This is called by an XLA custom call, set up by // index_ops.cc. -extern "C" void __attribute__((visibility("default"))) +extern "C" void TF_EXPORT argmax_float_2d_xla_impl(void* out, void** data) { tensorflow::argmax_float_2d_xla_impl(out, data); } diff --git a/tensorflow/compiler/tf2xla/xla_local_runtime_context.h b/tensorflow/compiler/tf2xla/xla_local_runtime_context.h index cd773d64ed..dca420d6ee 100644 --- a/tensorflow/compiler/tf2xla/xla_local_runtime_context.h +++ b/tensorflow/compiler/tf2xla/xla_local_runtime_context.h @@ -23,7 +23,7 @@ limitations under the License. // actually used. E.g. some ahead-of-time compiled computations don't need a // thread pool. namespace Eigen { -class ThreadPoolDevice; +struct ThreadPoolDevice; } namespace tensorflow { diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc index 8e06f0520e..253de20f25 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc @@ -15,7 +15,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" -#include #include #include "tensorflow/core/platform/logging.h" diff --git a/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc b/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc index 677080a862..332f4216dc 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc +++ b/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc @@ -53,8 +53,8 @@ void MatMul(const void* run_options_ptr, T* out, T* lhs, T* rhs, int64 m, typedef typename Eigen::Tensor::DimensionPair DimPair; int lhs_contract_dim = transpose_lhs ? 0 : 1; int rhs_contract_dim = transpose_rhs ? 1 : 0; - const Eigen::array dims( - DimPair(lhs_contract_dim, rhs_contract_dim)); + const Eigen::array dims({ + DimPair(lhs_contract_dim, rhs_contract_dim) }); // Matrix multiply is a special case of the "contract" operation where // the contraction is performed along dimension 1 of the lhs and dimension diff --git a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc index 384a978873..e45329c4ef 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc +++ b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc @@ -47,8 +47,8 @@ void MatMul(const void* run_options_ptr, T* out, T* lhs, T* rhs, int64 m, typedef typename Eigen::Tensor::DimensionPair DimPair; int lhs_contract_dim = transpose_lhs ? 0 : 1; int rhs_contract_dim = transpose_rhs ? 1 : 0; - const Eigen::array dims( - DimPair(lhs_contract_dim, rhs_contract_dim)); + const Eigen::array dims({ + DimPair(lhs_contract_dim, rhs_contract_dim)}); // Matrix multiply is a special case of the "contract" operation where // the contraction is performed along dimension 1 of the lhs and dimension diff --git a/tensorflow/compiler/xla/tests/custom_call_test.cc b/tensorflow/compiler/xla/tests/custom_call_test.cc index dc54c9defe..8b5b38b0b4 100644 --- a/tensorflow/compiler/xla/tests/custom_call_test.cc +++ b/tensorflow/compiler/xla/tests/custom_call_test.cc @@ -29,22 +29,23 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/dynamic_annotations.h" +#include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/test.h" -extern "C" void __attribute__((visibility("default"))) +extern "C" void TF_EXPORT R0F32Add2(float* out, float** in) { TF_ANNOTATE_MEMORY_IS_INITIALIZED(in, sizeof(float*)); *out = **in + 2.0f; } -extern "C" void __attribute__((visibility("default"))) +extern "C" void TF_EXPORT R2F32ReduceSum(float* out, float** in) { TF_ANNOTATE_MEMORY_IS_INITIALIZED(in, sizeof(float) * 4); float* array = in[0]; *out = array[0] + array[1] + array[2] + array[3]; } -extern "C" void __attribute__((visibility("default"))) +extern "C" void TF_EXPORT Add1ToValues(float* out, float** in) { TF_ANNOTATE_MEMORY_IS_INITIALIZED(in, sizeof(float) * 4); float* array = in[0]; -- GitLab From 7b77d12d4dda833c9aab18e6848130f15d699bd0 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 16 May 2017 10:03:48 -0700 Subject: [PATCH 672/697] Update ops-related pbtxt files. PiperOrigin-RevId: 156194847 --- .../core/ops/compat/ops_history.v1.pbtxt | 81 +++++++++++++++++++ tensorflow/core/ops/ops.pbtxt | 2 + 2 files changed, 83 insertions(+) diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt index 5dd18d4ba3..22d9c35116 100644 --- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt +++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt @@ -18607,6 +18607,40 @@ op { } } } +op { + name: "Reverse" + input_arg { + name: "tensor" + type_attr: "T" + } + input_arg { + name: "dims" + type: DT_BOOL + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_UINT8 + type: DT_INT8 + type: DT_INT32 + type: DT_INT64 + type: DT_BOOL + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_COMPLEX64 + type: DT_COMPLEX128 + type: DT_STRING + } + } + } +} op { name: "ReverseSequence" input_arg { @@ -18696,6 +18730,53 @@ op { } } } +op { + name: "ReverseV2" + input_arg { + name: "tensor" + type_attr: "T" + } + input_arg { + name: "axis" + type_attr: "Tidx" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "Tidx" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_UINT8 + type: DT_INT8 + type: DT_INT32 + type: DT_INT64 + type: DT_BOOL + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_COMPLEX64 + type: DT_COMPLEX128 + type: DT_STRING + } + } + } +} op { name: "Rint" input_arg { diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt index b0304f2486..0ff321aed5 100644 --- a/tensorflow/core/ops/ops.pbtxt +++ b/tensorflow/core/ops/ops.pbtxt @@ -19082,6 +19082,7 @@ op { type: DT_DOUBLE type: DT_COMPLEX64 type: DT_COMPLEX128 + type: DT_STRING } } } @@ -19183,6 +19184,7 @@ op { type: DT_DOUBLE type: DT_COMPLEX64 type: DT_COMPLEX128 + type: DT_STRING } } } -- GitLab From 36f44df832fb7cc8c893819813a11d96b953aa8f Mon Sep 17 00:00:00 2001 From: "freedom\" Koan-Sin Tan" Date: Wed, 17 May 2017 01:12:29 +0800 Subject: [PATCH 673/697] gif decoder returns 4-D tensor, remove the first dim (#9877) * gif decoder returns 4-D tensor, remove the first dim * Mention that only single-frame GIFs are supported --- tensorflow/examples/label_image/main.cc | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tensorflow/examples/label_image/main.cc b/tensorflow/examples/label_image/main.cc index 3351109f45..d98a5c31ab 100644 --- a/tensorflow/examples/label_image/main.cc +++ b/tensorflow/examples/label_image/main.cc @@ -30,6 +30,9 @@ limitations under the License. // the top of the main() function. // // The googlenet_graph.pb file included by default is created from Inception. +// +// Note that, for GIF inputs, to reuse existing code, only single-frame ones +// are supported. #include #include @@ -103,7 +106,10 @@ Status ReadTensorFromImageFile(const string& file_name, const int input_height, image_reader = DecodePng(root.WithOpName("png_reader"), file_reader, DecodePng::Channels(wanted_channels)); } else if (tensorflow::StringPiece(file_name).ends_with(".gif")) { - image_reader = DecodeGif(root.WithOpName("gif_reader"), file_reader); + // gif decoder returns 4-D tensor, remove the first dim + image_reader = Squeeze(root.WithOpName("squeeze_first_dim"), + DecodeGif(root.WithOpName("gif_reader"), + file_reader)); } else { // Assume if it's neither a PNG nor a GIF then it must be a JPEG. image_reader = DecodeJpeg(root.WithOpName("jpeg_reader"), file_reader, -- GitLab From deed76c536eb91f61ac6a31ce3c5175db60eb0a4 Mon Sep 17 00:00:00 2001 From: David Norman Date: Tue, 16 May 2017 18:13:14 +0100 Subject: [PATCH 674/697] [XLA] Add F16 support to the Literal protobuf and LiteralUtils class. (#9913) * Add F16 support to the Literal protobuf and LiteralUtils class. No support has been added to any public backend, however the unit tests demonstrate that the literals can store and retreive data correctly. * Changes after code review * Use an alternative form of local initialization * Change a few more C-style casts to C++ casts --- tensorflow/compiler/tf2xla/xla_helpers.cc | 8 +- tensorflow/compiler/xla/literal_util.cc | 65 +++++++- tensorflow/compiler/xla/literal_util.h | 167 +++++++------------ tensorflow/compiler/xla/literal_util_test.cc | 71 ++++++++ tensorflow/compiler/xla/primitive_util.cc | 5 + tensorflow/compiler/xla/primitive_util.h | 6 + tensorflow/compiler/xla/types.h | 4 + tensorflow/compiler/xla/xla_data.proto | 1 + 8 files changed, 216 insertions(+), 111 deletions(-) diff --git a/tensorflow/compiler/tf2xla/xla_helpers.cc b/tensorflow/compiler/tf2xla/xla_helpers.cc index 10d8b67bbd..f8589edafc 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.cc +++ b/tensorflow/compiler/tf2xla/xla_helpers.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/gtl/array_slice.h" @@ -89,7 +90,9 @@ xla::ComputationDataHandle XlaHelpers::IntegerLiteral( case xla::U16: LOG(FATAL) << "u16/s16 literals not yet implemented"; case xla::F16: - LOG(FATAL) << "f16 literals not yet implemented"; + literal = *xla::LiteralUtil::CreateR0( + static_cast(value)); + break; case xla::TUPLE: LOG(FATAL) << "tuple element type is not integral"; case xla::OPAQUE: @@ -107,6 +110,9 @@ xla::ComputationDataHandle XlaHelpers::FloatLiteral(xla::ComputationBuilder* b, xla::PrimitiveType type; TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type)); switch (type) { + case xla::F16: + return b->ConstantR0(static_cast(value)); + break; case xla::F32: return b->ConstantR0(static_cast(value)); break; diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc index e3bc856fc0..0f622f9153 100644 --- a/tensorflow/compiler/xla/literal_util.cc +++ b/tensorflow/compiler/xla/literal_util.cc @@ -148,6 +148,9 @@ template case S64: return CopyRange(src_literal, src_base, dest_literal, dest_base, copy_size); + case F16: + return CopyRange(src_literal, src_base, dest_literal, dest_base, + copy_size); case F32: return CopyRange(src_literal, src_base, dest_literal, dest_base, copy_size); @@ -178,6 +181,8 @@ template return *LiteralUtil::CreateR0(0); case S64: return *LiteralUtil::CreateR0(0); + case F16: + return *LiteralUtil::CreateR0(static_cast(0.0f)); case F32: return *LiteralUtil::CreateR0(0); case F64: @@ -187,8 +192,6 @@ template case S16: case U16: LOG(FATAL) << "u16/s16 literals not yet implemented"; - case F16: - LOG(FATAL) << "f16 literals not yet implemented"; case TUPLE: LOG(FATAL) << "tuple element type cannot take on value of 0"; case OPAQUE: @@ -222,7 +225,7 @@ template case U16: LOG(FATAL) << "u16/s16 literals not yet implemented"; case F16: - LOG(FATAL) << "f16 literals not yet implemented"; + return *LiteralUtil::CreateR0(static_cast(1.0f)); case TUPLE: LOG(FATAL) << "tuple element type cannot take on value of 1"; case OPAQUE: @@ -258,7 +261,8 @@ template case U16: LOG(FATAL) << "u16/s16 literals not yet implemented"; case F16: - LOG(FATAL) << "f16 literals not yet implemented"; + return *LiteralUtil::CreateR0( + static_cast(-std::numeric_limits::infinity())); case TUPLE: LOG(FATAL) << "tuple element type has no minimum value"; case OPAQUE: @@ -294,7 +298,8 @@ template case U16: LOG(FATAL) << "u16/s16 literals not yet implemented"; case F16: - LOG(FATAL) << "f16 literals not yet implemented"; + return *LiteralUtil::CreateR0( + static_cast(std::numeric_limits::infinity())); case TUPLE: LOG(FATAL) << "tuple element type has no maximum value"; case OPAQUE: @@ -498,6 +503,8 @@ template return tensorflow::strings::StrCat(Get(literal, multi_index)); case F64: return tensorflow::strings::StrCat(Get(literal, multi_index)); + case F16: + return tensorflow::strings::StrCat(Get(literal, multi_index)); default: return tensorflow::strings::StrCat( "[", PrimitiveType_Name(literal.shape().element_type()), "]"); @@ -652,6 +659,8 @@ template return reinterpret_cast(literal.f32s().data()); case F64: return reinterpret_cast(literal.f64s().data()); + case F16: + return reinterpret_cast(literal.f16s().data()); default: LOG(FATAL) << "primitive type not supported in literals: " << PrimitiveType_Name(literal.shape().element_type()); @@ -691,6 +700,8 @@ template break; case F64: Resize(num_elements, 0, literal); + case F16: + Resize(num_elements, static_cast(0.0f), literal); break; default: LOG(FATAL) << "primitive type not supported in literals: " @@ -728,6 +739,9 @@ template case F64: actual = literal.f64s_size(); break; + case F16: + actual = literal.f16s().size() / sizeof(half); + break; default: return tensorflow::errors::Unimplemented( "unhandled element type for literal validation: " + @@ -818,6 +832,8 @@ bool EqualElements(const Literal& literal1, const Literal& literal2, return EqualElements(literal1, literal2, 0, &multi_index); case F64: return EqualElements(literal1, literal2, 0, &multi_index); + case F16: + return EqualElements(literal1, literal2, 0, &multi_index); default: LOG(FATAL) << "Unimplemented: LiteralUtil::Equal for type " << PrimitiveType_Name(literal1.shape().element_type()); @@ -916,6 +932,19 @@ LiteralUtil::GetMutableArraySlice(Literal* literal) { values->size()); } +template <> +/* static */ tensorflow::gtl::MutableArraySlice +LiteralUtil::GetMutableArraySlice(Literal* literal) { + // C++11 standard, basic_string 21.4.1.5, values should be stored + // contiguously. From C++17 a mutable data() member will be provided. + // TODO - there is an endianess problem here. fix it, or wait for uint16 + // support in protobuf + auto values = literal->mutable_f16s(); + return tensorflow::gtl::MutableArraySlice( + reinterpret_cast(&(*values)[0]), + values->size() / sizeof(half)); +} + template <> /* static */ tensorflow::gtl::ArraySlice LiteralUtil::GetArraySlice( const Literal& literal) { @@ -976,6 +1005,15 @@ LiteralUtil::GetArraySlice(const Literal& literal) { return literal.f64s(); } +template <> +/* static */ tensorflow::gtl::ArraySlice +LiteralUtil::GetArraySlice(const Literal& literal) { + CHECK_EQ(literal.shape().element_type(), F16); + return tensorflow::gtl::ArraySlice( + reinterpret_cast(literal.f16s().data()), + literal.f16s().size() / sizeof(half)); +} + template static bool AllElementsEqualValue(const Literal& literal, NativeT value) { for (int64 i = 0; i < ShapeUtil::ElementsIn(literal.shape()); ++i) { @@ -1015,6 +1053,8 @@ static bool AllElementsEqualValue(const Literal& literal, NativeT value) { return AllElementsEqualValue(literal, value); case F64: return AllElementsEqualValue(literal, value); + case F16: + return AllElementsEqualValue(literal, static_cast(value)); case PRED: if (value == 0) { return AllElementsEqualValue(literal, false); @@ -1034,6 +1074,8 @@ static bool AllElementsEqualValue(const Literal& literal, NativeT value) { return AllElementsEqualValue(literal, value); case F64: return AllElementsEqualValue(literal, value); + case F16: + return AllElementsEqualValue(literal, static_cast(value)); default: return false; } @@ -1058,6 +1100,8 @@ static bool AllElementsEqualValue(const Literal& literal, NativeT value) { return Get(literal, indices) == 0.0f; case F64: return Get(literal, indices) == 0.0; + case F16: + return Get(literal, indices) == static_cast(0.0f); case PRED: return Get(literal, indices) == false; default: @@ -1128,4 +1172,15 @@ template <> literal->mutable_f64s()->Resize(num_elements, value); } +template <> +/* static */ void LiteralUtil::Resize(int64 num_elements, half value, + Literal* literal) { + CHECK_EQ(ShapeUtil::ElementsIn(literal->shape()), num_elements); + literal->mutable_f16s()->resize(num_elements * sizeof(half)); + auto data = GetMutableArraySlice(literal); + for (int i = 0; i < num_elements; i++) { + data[i] = value; + } +} + } // namespace xla diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h index 0ea7186040..2da010d56e 100644 --- a/tensorflow/compiler/xla/literal_util.h +++ b/tensorflow/compiler/xla/literal_util.h @@ -505,6 +505,10 @@ template <> /* static */ tensorflow::gtl::ArraySlice LiteralUtil::GetArraySlice(const Literal& literal); +template <> +/* static */ tensorflow::gtl::ArraySlice +LiteralUtil::GetArraySlice(const Literal& literal); + template <> /* static */ tensorflow::gtl::MutableArraySlice LiteralUtil::GetMutableArraySlice(Literal* literal); @@ -541,6 +545,50 @@ template <> /* static */ tensorflow::gtl::MutableArraySlice LiteralUtil::GetMutableArraySlice(Literal* literal); +template <> +/* static */ tensorflow::gtl::MutableArraySlice +LiteralUtil::GetMutableArraySlice(Literal* literal); + +template <> +/* static */ void LiteralUtil::Resize(int64 num_elements, bool value, + Literal* literal); + +template <> +/* static */ void LiteralUtil::Resize(int64 num_elements, int8 value, + Literal* literal); + +template <> +/* static */ void LiteralUtil::Resize(int64 num_elements, uint8 value, + Literal* literal); + +template <> +/* static */ void LiteralUtil::Resize(int64 num_elements, int32 value, + Literal* literal); + +template <> +/* static */ void LiteralUtil::Resize(int64 num_elements, uint32 value, + Literal* literal); + +template <> +/* static */ void LiteralUtil::Resize(int64 num_elements, int64 value, + Literal* literal); + +template <> +/* static */ void LiteralUtil::Resize(int64 num_elements, uint64 value, + Literal* literal); + +template <> +/* static */ void LiteralUtil::Resize(int64 num_elements, float value, + Literal* literal); + +template <> +/* static */ void LiteralUtil::Resize(int64 num_elements, double value, + Literal* literal); + +template <> +/* static */ void LiteralUtil::Resize(int64 num_elements, half value, + Literal* literal); + template /* static */ std::unique_ptr LiteralUtil::CreateR0(NativeT value) { auto literal = MakeUnique(); @@ -770,6 +818,14 @@ template <> return literal.u8s()[linear_index]; } +template <> +/* static */ inline half LiteralUtil::Get( + const Literal& literal, tensorflow::gtl::ArraySlice multi_index) { + CHECK(literal.shape().element_type() == F16); + int64 linear_index = LinearIndex(literal, multi_index); + return GetArraySlice(literal)[linear_index]; +} + template /* static */ void LiteralUtil::Set( Literal* literal, tensorflow::gtl::ArraySlice multi_index, @@ -834,76 +890,13 @@ template } while (IndexUtil::BumpIndices(literal.shape(), &indices)); } -template <> -/* static */ inline void LiteralUtil::PopulateR0(bool value, - Literal* literal) { - *literal->mutable_shape() = - ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType(), {}); - literal->mutable_preds()->Add(value); -} - -template <> -/* static */ inline void LiteralUtil::PopulateR0(uint8 value, - Literal* literal) { - *literal->mutable_shape() = - ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType(), {}); - literal->mutable_u8s()->push_back(value); -} - -template <> -/* static */ inline void LiteralUtil::PopulateR0(int8 value, - Literal* literal) { - *literal->mutable_shape() = - ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType(), {}); - literal->mutable_u8s()->push_back(value); -} - -template <> -/* static */ inline void LiteralUtil::PopulateR0(uint32 value, - Literal* literal) { - *literal->mutable_shape() = - ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType(), {}); - literal->mutable_u32s()->Add(value); -} - -template <> -/* static */ inline void LiteralUtil::PopulateR0(int32 value, - Literal* literal) { - *literal->mutable_shape() = - ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType(), {}); - literal->mutable_s32s()->Add(value); -} - -template <> -/* static */ inline void LiteralUtil::PopulateR0(uint64 value, - Literal* literal) { - *literal->mutable_shape() = - ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType(), {}); - literal->mutable_u64s()->Add(value); -} - -template <> -/* static */ inline void LiteralUtil::PopulateR0(int64 value, - Literal* literal) { - *literal->mutable_shape() = - ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType(), {}); - literal->mutable_s64s()->Add(value); -} - -template <> -/* static */ inline void LiteralUtil::PopulateR0(float value, - Literal* literal) { - *literal->mutable_shape() = - ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType(), {}); - literal->mutable_f32s()->Add(value); -} - -template <> -/* static */ inline void LiteralUtil::PopulateR0(double value, - Literal* literal) { +template +/* static */ inline void LiteralUtil::PopulateR0(NativeT value, + Literal* literal) { *literal->mutable_shape() = - ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType(), {}); - literal->mutable_f64s()->Add(value); + ShapeUtil::MakeShape( + primitive_util::NativeToPrimitiveType(), {}); + Resize(1, value, literal); } template @@ -1116,42 +1109,6 @@ template return result_literal; } -template <> -/* static */ void LiteralUtil::Resize(int64 num_elements, bool value, - Literal* literal); - -template <> -/* static */ void LiteralUtil::Resize(int64 num_elements, int8 value, - Literal* literal); - -template <> -/* static */ void LiteralUtil::Resize(int64 num_elements, uint8 value, - Literal* literal); - -template <> -/* static */ void LiteralUtil::Resize(int64 num_elements, int32 value, - Literal* literal); - -template <> -/* static */ void LiteralUtil::Resize(int64 num_elements, uint32 value, - Literal* literal); - -template <> -/* static */ void LiteralUtil::Resize(int64 num_elements, int64 value, - Literal* literal); - -template <> -/* static */ void LiteralUtil::Resize(int64 num_elements, uint64 value, - Literal* literal); - -template <> -/* static */ void LiteralUtil::Resize(int64 num_elements, float value, - Literal* literal); - -template <> -/* static */ void LiteralUtil::Resize(int64 num_elements, double value, - Literal* literal); - template /* static */ std::unique_ptr LiteralUtil::CreateFullWithMonotonicDim0MajorLayout( diff --git a/tensorflow/compiler/xla/literal_util_test.cc b/tensorflow/compiler/xla/literal_util_test.cc index 7acb9933da..9a09822174 100644 --- a/tensorflow/compiler/xla/literal_util_test.cc +++ b/tensorflow/compiler/xla/literal_util_test.cc @@ -105,6 +105,9 @@ TEST_F(LiteralUtilTest, LiteralScalarToString) { auto f32_lit = LiteralUtil::CreateR0(3.14f); ASSERT_EQ("3.14", LiteralUtil::ToString(*f32_lit)); + + auto f16_lit = LiteralUtil::CreateR0(static_cast(0.5f)); + ASSERT_EQ("0.5", LiteralUtil::ToString(*f16_lit)); } TEST_F(LiteralUtilTest, LiteralVectorToString) { @@ -373,6 +376,15 @@ TEST_F(LiteralUtilTest, IsAll) { EXPECT_FALSE( LiteralUtil::IsAll(*LiteralUtil::CreateR2({{9, 8}, {8, 8}}), 8)); + half h8(8.0f); + half h9(9.0f); + EXPECT_TRUE( + LiteralUtil::IsAll(*LiteralUtil::CreateR2({{h8}, {h8}}), 8)); + EXPECT_FALSE( + LiteralUtil::IsAll(*LiteralUtil::CreateR2({{h8}, {h9}}), 8)); + EXPECT_FALSE( + LiteralUtil::IsAll(*LiteralUtil::CreateR2({{h9}, {h8}}), 8)); + auto uint64_max = std::numeric_limits::max(); EXPECT_FALSE(LiteralUtil::IsAll( *LiteralUtil::CreateR2( @@ -659,6 +671,30 @@ TEST_F(LiteralUtilTest, PopulateWithValueR2U64) { EXPECT_TRUE(LiteralUtil::Equal(output, *expected)); } +TEST_F(LiteralUtilTest, PopulateWithValueR0F16) { + Literal output; + half h(0.25f); + LiteralUtil::PopulateWithValue(h, {}, &output); + auto expected = LiteralUtil::CreateR0(h); + EXPECT_TRUE(LiteralUtil::Equal(output, *expected)); +} + +TEST_F(LiteralUtilTest, PopulateWithValueR1F16) { + Literal output; + half h(0.5f); + LiteralUtil::PopulateWithValue(h, {3}, &output); + auto expected = LiteralUtil::CreateR1({h, h, h}); + EXPECT_TRUE(LiteralUtil::Equal(output, *expected)); +} + +TEST_F(LiteralUtilTest, PopulateWithValueR2F16) { + Literal output; + half h(2.0f); + LiteralUtil::PopulateWithValue(h, {2, 2}, &output); + auto expected = LiteralUtil::CreateR2({{h, h}, {h, h}}); + EXPECT_TRUE(LiteralUtil::Equal(output, *expected)); +} + TEST_F(LiteralUtilTest, ReplicateR2U32) { auto input = LiteralUtil::CreateR2( {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}); @@ -730,6 +766,41 @@ TEST_F(LiteralUtilTest, CopyScalars) { EXPECT_EQ(LiteralUtil::Get(*vect, {4}), 17); } +TEST_F(LiteralUtilTest, F16) { + // Verify that the internal data views are consistent and that they + // are in little endian format + // TODO - modify if we make the data format machine endianess dependent + auto m1 = LiteralUtil::CreateFromShape(ShapeUtil::MakeShape(F16, {2, 2})); + Literal* l1 = m1.get(); + const char* d1 = (const char*)LiteralUtil::InternalData(*l1); + EXPECT_EQ(d1[0], 0); + EXPECT_EQ(d1[1], 0); + EXPECT_EQ(d1[2], 0); + EXPECT_EQ(d1[3], 0); + EXPECT_EQ(d1[4], 0); + EXPECT_EQ(d1[5], 0); + EXPECT_EQ(d1[6], 0); + EXPECT_EQ(d1[7], 0); + EXPECT_EQ(LiteralUtil::InternalData(*l1), + LiteralUtil::MutableInternalData(l1)); + + half h1(1.0f); + half h2(2.0f); + auto m2 = LiteralUtil::CreateR2({{h1, h2}, {h2, h1}}); + Literal* l2 = m2.get(); + const char* d2 = (const char*)LiteralUtil::InternalData(*l2); + EXPECT_EQ(d2[0], 0); + EXPECT_EQ(d2[1], 0x3C); + EXPECT_EQ(d2[2], 0); + EXPECT_EQ(d2[3], 0x40); + EXPECT_EQ(d2[4], 0); + EXPECT_EQ(d2[5], 0x40); + EXPECT_EQ(d2[6], 0); + EXPECT_EQ(d2[7], 0x3C); + EXPECT_EQ(LiteralUtil::InternalData(*l2), + LiteralUtil::MutableInternalData(l2)); +} + TEST_F(LiteralUtilTest, Populate) { struct PopulateData { std::vector dimensions; diff --git a/tensorflow/compiler/xla/primitive_util.cc b/tensorflow/compiler/xla/primitive_util.cc index e3909ae8e9..e4e37177a2 100644 --- a/tensorflow/compiler/xla/primitive_util.cc +++ b/tensorflow/compiler/xla/primitive_util.cc @@ -78,6 +78,11 @@ PrimitiveType NativeToPrimitiveType() { return F64; } +template <> +PrimitiveType NativeToPrimitiveType() { + return F16; +} + bool IsFloatingPointType(PrimitiveType type) { return type == F16 || type == F32 || type == F64; } diff --git a/tensorflow/compiler/xla/primitive_util.h b/tensorflow/compiler/xla/primitive_util.h index 78f0ee6f59..162a11c7d2 100644 --- a/tensorflow/compiler/xla/primitive_util.h +++ b/tensorflow/compiler/xla/primitive_util.h @@ -75,6 +75,8 @@ template <> PrimitiveType NativeToPrimitiveType(); template <> PrimitiveType NativeToPrimitiveType(); +template <> +PrimitiveType NativeToPrimitiveType(); bool IsFloatingPointType(PrimitiveType type); @@ -150,6 +152,10 @@ template <> struct PrimitiveTypeToNative { using type = double; }; +template <> +struct PrimitiveTypeToNative { + using type = half; +}; } // namespace primitive_util } // namespace xla diff --git a/tensorflow/compiler/xla/types.h b/tensorflow/compiler/xla/types.h index 8258031a2c..8d8e66715a 100644 --- a/tensorflow/compiler/xla/types.h +++ b/tensorflow/compiler/xla/types.h @@ -18,6 +18,8 @@ limitations under the License. #include "tensorflow/core/platform/types.h" +#include + namespace xla { using ::tensorflow::string; @@ -32,6 +34,8 @@ using ::tensorflow::uint16; using ::tensorflow::uint32; using ::tensorflow::uint64; +using ::Eigen::half; + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_TYPES_H_ diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto index e5b94fcefe..52189fb5d7 100644 --- a/tensorflow/compiler/xla/xla_data.proto +++ b/tensorflow/compiler/xla/xla_data.proto @@ -286,6 +286,7 @@ message Literal { repeated float f32s = 8; repeated double f64s = 9; repeated Literal tuple_literals = 10; + bytes f16s = 11; // Note: the F16s are encoded in little endian byte order } message WindowDimension { -- GitLab From 970e81f2d192a7b47a5f10306574fd3d3f5cb2fb Mon Sep 17 00:00:00 2001 From: Ikaro Silva Date: Tue, 16 May 2017 13:16:17 -0400 Subject: [PATCH 675/697] [DOCS] Updating PredictionType args definition (#9919) [DOCS] Updating PredictionType value for DynamicRNN --- .../learn/python/learn/estimators/dynamic_rnn_estimator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py index fc092fccd7..1724d7599d 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py @@ -597,8 +597,8 @@ class DynamicRnnEstimator(estimator.Estimator): `ProblemType.CLASSIFICATION` or `ProblemType.LINEAR_REGRESSION`. prediction_type: whether the `Estimator` should return a value for each step in the sequence, or just a single value for the final time step. - Must be one of `ProblemType.SINGLE_VALUE` or - `ProblemType.MULTIPLE_VALUE`. + Must be one of `PredictionType.SINGLE_VALUE` or + `PredictionType.MULTIPLE_VALUE`. sequence_feature_columns: An iterable containing all the feature columns describing sequence features. All items in the iterable should be instances of classes derived from `FeatureColumn`. -- GitLab From e58236a0586462906fbae2bee6713eb24cad869e Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 16 May 2017 10:31:58 -0700 Subject: [PATCH 676/697] Clarify error message. The checked condition is not whether model_fn has four arguments, but whether the fourth argument is called 'params'. PiperOrigin-RevId: 156198632 --- .../contrib/learn/python/learn/estimators/estimator.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py index e8142b659b..1af31b933b 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py @@ -1085,8 +1085,9 @@ class Estimator(BaseEstimator): # Check number of arguments of the given function matches requirements. model_fn_args = _model_fn_args(model_fn) if params is not None and 'params' not in model_fn_args: - raise ValueError('Estimator\'s model_fn (%s) has less than 4 ' - 'arguments, but not None params (%s) are passed.' % + raise ValueError('Estimator\'s model_fn (%s) does not have a params ' + 'argument, but params (%s) were passed to the ' + 'Estimator\'s constructor.' % (model_fn, params)) if params is None and 'params' in model_fn_args: logging.warning('Estimator\'s model_fn (%s) includes params ' -- GitLab From 69433c1f1adef96fde2074b05d3362e88d8587de Mon Sep 17 00:00:00 2001 From: Mathew Wicks Date: Wed, 17 May 2017 06:12:27 +1200 Subject: [PATCH 677/697] TFLearn Estimator Summary Writer Fix (#7555) * Fix evaluation summary writing. TFLearn * Update estimator.py Fixed White Space * Fix similar issue. Fix similar issue under, `tensorflow/python/estimators/estimator.py` * Filter out "global_step" This is seemingly the best way to ensure we don't write "global_step" to the summary, (which would be useless), while not modifying the dictionary. (Pull request for the other `estimator.py` incoming.) * Filter "global_step" [part 2] Same as other `estimator.py`. --- .../contrib/learn/python/learn/estimators/estimator.py | 10 ++++++++-- tensorflow/python/estimator/estimator.py | 10 ++++++++-- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py index e8142b659b..3f54e5ee2c 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py @@ -331,13 +331,19 @@ def _write_dict_to_summary(output_dir, for key in dictionary: if dictionary[key] is None: continue + if key == "global_step": + continue value = summary_proto.value.add() value.tag = key - if (isinstance(dictionary[key], np.float32) or + if (isinstance(dictionary[key], np.float32) or isinstance(dictionary[key], float)): value.simple_value = float(dictionary[key]) + elif (isinstance(dictionary[key], np.int64) or + isinstance(dictionary[key], np.int32) or + isinstance(dictionary[key], int)): + value.simple_value = int(dictionary[key]) else: - logging.warn('Skipping summary for %s, must be a float or np.float32.', + logging.warn('Skipping summary for %s, must be a float, np.float32, np.int64, np.int32 or int.', key) summary_writer.add_summary(summary_proto, current_global_step) summary_writer.flush() diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py index 735bfca4f7..f32567b880 100644 --- a/tensorflow/python/estimator/estimator.py +++ b/tensorflow/python/estimator/estimator.py @@ -817,13 +817,19 @@ def _write_dict_to_summary(output_dir, for key in dictionary: if dictionary[key] is None: continue + if key == "global_step": + continue value = summary_proto.value.add() value.tag = key - if (isinstance(dictionary[key], np.float32) or + if (isinstance(dictionary[key], np.float32) or isinstance(dictionary[key], float)): value.simple_value = float(dictionary[key]) + elif (isinstance(dictionary[key], np.int64) or + isinstance(dictionary[key], np.int32) or + isinstance(dictionary[key], int)): + value.simple_value = int(dictionary[key]) else: - logging.warn('Skipping summary for %s, must be a float or np.float32.', + logging.warn('Skipping summary for %s, must be a float, np.float32, np.int64, np.int32 or int.', key) summary_writer.add_summary(summary_proto, current_global_step) summary_writer.flush() -- GitLab From bc59b80fc8688b4249a8bbec424d1d1fb8d7dd7b Mon Sep 17 00:00:00 2001 From: Ian Langmore Date: Tue, 16 May 2017 11:24:01 -0700 Subject: [PATCH 678/697] LinearOperator.apply --> .matmul and .matvec LinearOperator.solve --> .solve and .solvevec PiperOrigin-RevId: 156206294 --- .../bijectors/affine_linear_operator_impl.py | 2 +- .../python/ops/mvn_linear_operator.py | 10 +- .../ops/vector_laplace_linear_operator.py | 10 +- .../linear_operator_composition_test.py | 12 +- .../kernel_tests/linear_operator_diag_test.py | 12 +- .../linear_operator_identity_test.py | 64 ++++---- .../kernel_tests/linear_operator_test.py | 51 ++++-- .../linalg/python/ops/linear_operator.py | 153 +++++++++++++++--- .../python/ops/linear_operator_addition.py | 2 +- .../python/ops/linear_operator_composition.py | 24 +-- .../linalg/python/ops/linear_operator_diag.py | 10 +- .../python/ops/linear_operator_full_matrix.py | 8 +- .../python/ops/linear_operator_identity.py | 28 ++-- .../python/ops/linear_operator_test_util.py | 22 +-- .../linalg/python/ops/linear_operator_tril.py | 8 +- .../python/ops/linear_operator_udvh_update.py | 22 +-- .../linalg/python/ops/linear_operator_util.py | 4 +- 17 files changed, 288 insertions(+), 154 deletions(-) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/affine_linear_operator_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/affine_linear_operator_impl.py index c1be9d5a77..ae380b5cb2 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/affine_linear_operator_impl.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/affine_linear_operator_impl.py @@ -193,7 +193,7 @@ class AffineLinearOperator(bijector.Bijector): y, expand_batch_dim=False) with ops.control_dependencies(self._maybe_collect_assertions() if self.validate_args else []): - y = self.scale.apply(y) + y = self.scale.matmul(y) y = self._shaper.undo_make_batch_of_event_sample_matrices( y, sample_shape, expand_batch_dim=False) if self.shift is not None: diff --git a/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py b/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py index 15e699906f..b25250d367 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py @@ -238,7 +238,7 @@ class MultivariateNormalLinearOperator( if distribution_util.is_diagonal_scale(self.scale): return array_ops.matrix_diag(math_ops.square(self.scale.diag_part())) else: - return self.scale.apply(self.scale.to_dense(), adjoint_arg=True) + return self.scale.matmul(self.scale.to_dense(), adjoint_arg=True) def _variance(self): if distribution_util.is_diagonal_scale(self.scale): @@ -246,10 +246,10 @@ class MultivariateNormalLinearOperator( elif (isinstance(self.scale, linalg.LinearOperatorUDVHUpdate) and self.scale.is_self_adjoint): return array_ops.matrix_diag_part( - self.scale.apply(self.scale.to_dense())) + self.scale.matmul(self.scale.to_dense())) else: return array_ops.matrix_diag_part( - self.scale.apply(self.scale.to_dense(), adjoint_arg=True)) + self.scale.matmul(self.scale.to_dense(), adjoint_arg=True)) def _stddev(self): if distribution_util.is_diagonal_scale(self.scale): @@ -257,10 +257,10 @@ class MultivariateNormalLinearOperator( elif (isinstance(self.scale, linalg.LinearOperatorUDVHUpdate) and self.scale.is_self_adjoint): return math_ops.sqrt(array_ops.matrix_diag_part( - self.scale.apply(self.scale.to_dense()))) + self.scale.matmul(self.scale.to_dense()))) else: return math_ops.sqrt(array_ops.matrix_diag_part( - self.scale.apply(self.scale.to_dense(), adjoint_arg=True))) + self.scale.matmul(self.scale.to_dense(), adjoint_arg=True))) def _mode(self): return self._mean() diff --git a/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py b/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py index 346835dd59..fd2c46d94d 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py +++ b/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py @@ -266,7 +266,7 @@ class VectorLaplaceLinearOperator( if distribution_util.is_diagonal_scale(self.scale): return 2. * array_ops.matrix_diag(math_ops.square(self.scale.diag_part())) else: - return 2. * self.scale.apply(self.scale.to_dense(), adjoint_arg=True) + return 2. * self.scale.matmul(self.scale.to_dense(), adjoint_arg=True) def _variance(self): if distribution_util.is_diagonal_scale(self.scale): @@ -274,10 +274,10 @@ class VectorLaplaceLinearOperator( elif (isinstance(self.scale, linalg.LinearOperatorUDVHUpdate) and self.scale.is_self_adjoint): return array_ops.matrix_diag_part( - 2. * self.scale.apply(self.scale.to_dense())) + 2. * self.scale.matmul(self.scale.to_dense())) else: return 2. * array_ops.matrix_diag_part( - self.scale.apply(self.scale.to_dense(), adjoint_arg=True)) + self.scale.matmul(self.scale.to_dense(), adjoint_arg=True)) def _stddev(self): if distribution_util.is_diagonal_scale(self.scale): @@ -285,10 +285,10 @@ class VectorLaplaceLinearOperator( elif (isinstance(self.scale, linalg.LinearOperatorUDVHUpdate) and self.scale.is_self_adjoint): return np.sqrt(2) * math_ops.sqrt(array_ops.matrix_diag_part( - self.scale.apply(self.scale.to_dense()))) + self.scale.matmul(self.scale.to_dense()))) else: return np.sqrt(2) * math_ops.sqrt(array_ops.matrix_diag_part( - self.scale.apply(self.scale.to_dense(), adjoint_arg=True))) + self.scale.matmul(self.scale.to_dense(), adjoint_arg=True))) def _mode(self): return self._mean() diff --git a/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_composition_test.py b/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_composition_test.py index 0585a0ba5a..e2a7f5fbe1 100644 --- a/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_composition_test.py +++ b/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_composition_test.py @@ -77,9 +77,9 @@ class SquareLinearOperatorCompositionTest( # Convert back to Tensor. Needed if use_placeholder, since then we have # already evaluated each matrix to a numpy array. - apply_order_list = list(reversed(matrices)) - mat = ops.convert_to_tensor(apply_order_list[0]) - for other_mat in apply_order_list[1:]: + matmul_order_list = list(reversed(matrices)) + mat = ops.convert_to_tensor(matmul_order_list[0]) + for other_mat in matmul_order_list[1:]: mat = math_ops.matmul(other_mat, mat) return operator, mat, feed_dict @@ -188,9 +188,9 @@ class NonSquareLinearOperatorCompositionTest( # Convert back to Tensor. Needed if use_placeholder, since then we have # already evaluated each matrix to a numpy array. - apply_order_list = list(reversed(matrices)) - mat = ops.convert_to_tensor(apply_order_list[0]) - for other_mat in apply_order_list[1:]: + matmul_order_list = list(reversed(matrices)) + mat = ops.convert_to_tensor(matmul_order_list[0]) + for other_mat in matmul_order_list[1:]: mat = math_ops.matmul(other_mat, mat) return operator, mat, feed_dict diff --git a/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_diag_test.py b/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_diag_test.py index 3bb81a4333..397bfa2215 100644 --- a/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_diag_test.py +++ b/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_diag_test.py @@ -122,7 +122,7 @@ class LinearOperatorDiagTest( with self.assertRaisesRegexp(ValueError, "must have at least 1 dimension"): linalg.LinearOperatorDiag(1.) - def test_broadcast_apply_and_solve(self): + def test_broadcast_matmul_and_solve(self): # These cannot be done in the automated (base test class) tests since they # test shapes that tf.matmul cannot handle. # In particular, tf.matmul does not broadcast. @@ -130,7 +130,7 @@ class LinearOperatorDiagTest( x = random_ops.random_normal(shape=(2, 2, 3, 4)) # This LinearOperatorDiag will be brodacast to (2, 2, 3, 3) during solve - # and apply with 'x' as the argument. + # and matmul with 'x' as the argument. diag = random_ops.random_uniform(shape=(2, 1, 3)) operator = linalg.LinearOperatorDiag(diag, is_self_adjoint=True) self.assertAllEqual((2, 1, 3, 3), operator.shape) @@ -140,10 +140,10 @@ class LinearOperatorDiagTest( mat = array_ops.matrix_diag(diag_broadcast) self.assertAllEqual((2, 2, 3, 3), mat.get_shape()) # being pedantic. - operator_apply = operator.apply(x) - mat_apply = math_ops.matmul(mat, x) - self.assertAllEqual(operator_apply.get_shape(), mat_apply.get_shape()) - self.assertAllClose(*sess.run([operator_apply, mat_apply])) + operator_matmul = operator.matmul(x) + mat_matmul = math_ops.matmul(mat, x) + self.assertAllEqual(operator_matmul.get_shape(), mat_matmul.get_shape()) + self.assertAllClose(*sess.run([operator_matmul, mat_matmul])) operator_solve = operator.solve(x) mat_solve = linalg_ops.matrix_solve(mat, x) diff --git a/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_identity_test.py b/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_identity_test.py index 36a255f3d5..5faf2c432b 100644 --- a/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_identity_test.py +++ b/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_identity_test.py @@ -77,14 +77,14 @@ class LinearOperatorIdentityTest( operator = linalg_lib.LinearOperatorIdentity(num_rows=2) operator.assert_self_adjoint().run() # Should not fail - def test_float16_apply(self): + def test_float16_matmul(self): # float16 cannot be tested by base test class because tf.matrix_solve does # not work with float16. with self.test_session(): operator = linalg_lib.LinearOperatorIdentity( num_rows=2, dtype=dtypes.float16) x = rng.randn(2, 3).astype(np.float16) - y = operator.apply(x) + y = operator.matmul(x) self.assertAllClose(x, y.eval()) def test_non_scalar_num_rows_raises_static(self): @@ -147,7 +147,7 @@ class LinearOperatorIdentityTest( operator = linalg_lib.LinearOperatorIdentity(num_rows=2) x = rng.randn(3, 3).astype(np.float32) with self.assertRaisesRegexp(ValueError, "Dimensions.*not compatible"): - operator.apply(x) + operator.matmul(x) def test_wrong_matrix_dimensions_raises_dynamic(self): num_rows = array_ops.placeholder(dtypes.int32) @@ -156,7 +156,7 @@ class LinearOperatorIdentityTest( with self.test_session(): operator = linalg_lib.LinearOperatorIdentity( num_rows, assert_proper_shapes=True) - y = operator.apply(x) + y = operator.matmul(x) with self.assertRaisesOpError("Incompatible.*dimensions"): y.eval(feed_dict={num_rows: 2, x: rng.rand(3, 3)}) @@ -168,11 +168,11 @@ class LinearOperatorIdentityTest( x = random_ops.random_normal(shape=(1, 2, 3, 4)) operator = linalg_lib.LinearOperatorIdentity(num_rows=3, dtype=x.dtype) - operator_apply = operator.apply(x) + operator_matmul = operator.matmul(x) expected = x - self.assertAllEqual(operator_apply.get_shape(), expected.get_shape()) - self.assertAllClose(*sess.run([operator_apply, expected])) + self.assertAllEqual(operator_matmul.get_shape(), expected.get_shape()) + self.assertAllClose(*sess.run([operator_matmul, expected])) def test_default_batch_shape_broadcasts_with_everything_dynamic(self): # These cannot be done in the automated (base test class) tests since they @@ -182,15 +182,15 @@ class LinearOperatorIdentityTest( x = array_ops.placeholder(dtypes.float32) operator = linalg_lib.LinearOperatorIdentity(num_rows=3, dtype=x.dtype) - operator_apply = operator.apply(x) + operator_matmul = operator.matmul(x) expected = x feed_dict = {x: rng.randn(1, 2, 3, 4)} self.assertAllClose( - *sess.run([operator_apply, expected], feed_dict=feed_dict)) + *sess.run([operator_matmul, expected], feed_dict=feed_dict)) - def test_broadcast_apply_static_shapes(self): + def test_broadcast_matmul_static_shapes(self): # These cannot be done in the automated (base test class) tests since they # test shapes that tf.batch_matmul cannot handle. # In particular, tf.batch_matmul does not broadcast. @@ -204,14 +204,14 @@ class LinearOperatorIdentityTest( # Batch matrix of zeros with the broadcast shape of x and operator. zeros = array_ops.zeros(shape=(2, 2, 3, 4), dtype=x.dtype) - # Expected result of apply and solve. + # Expected result of matmul and solve. expected = x + zeros - operator_apply = operator.apply(x) - self.assertAllEqual(operator_apply.get_shape(), expected.get_shape()) - self.assertAllClose(*sess.run([operator_apply, expected])) + operator_matmul = operator.matmul(x) + self.assertAllEqual(operator_matmul.get_shape(), expected.get_shape()) + self.assertAllClose(*sess.run([operator_matmul, expected])) - def test_broadcast_apply_dynamic_shapes(self): + def test_broadcast_matmul_dynamic_shapes(self): # These cannot be done in the automated (base test class) tests since they # test shapes that tf.batch_matmul cannot handle. # In particular, tf.batch_matmul does not broadcast. @@ -229,12 +229,12 @@ class LinearOperatorIdentityTest( # Batch matrix of zeros with the broadcast shape of x and operator. zeros = array_ops.zeros(shape=(2, 2, 3, 4), dtype=x.dtype) - # Expected result of apply and solve. + # Expected result of matmul and solve. expected = x + zeros - operator_apply = operator.apply(x) + operator_matmul = operator.matmul(x) self.assertAllClose( - *sess.run([operator_apply, expected], feed_dict=feed_dict)) + *sess.run([operator_matmul, expected], feed_dict=feed_dict)) def test_is_x_flags(self): # The is_x flags are by default all True. @@ -332,7 +332,7 @@ class LinearOperatorScaledIdentityTest( with self.assertRaisesOpError("not self-adjoint"): operator.assert_self_adjoint().run() - def test_float16_apply(self): + def test_float16_matmul(self): # float16 cannot be tested by base test class because tf.matrix_solve does # not work with float16. with self.test_session(): @@ -340,7 +340,7 @@ class LinearOperatorScaledIdentityTest( operator = linalg_lib.LinearOperatorScaledIdentity( num_rows=2, multiplier=multiplier) x = rng.randn(2, 3).astype(np.float16) - y = operator.apply(x) + y = operator.matmul(x) self.assertAllClose(multiplier[..., None, None] * x, y.eval()) def test_non_scalar_num_rows_raises_static(self): @@ -354,7 +354,7 @@ class LinearOperatorScaledIdentityTest( num_rows=2, multiplier=2.2) x = rng.randn(3, 3).astype(np.float32) with self.assertRaisesRegexp(ValueError, "Dimensions.*not compatible"): - operator.apply(x) + operator.matmul(x) def test_wrong_matrix_dimensions_raises_dynamic(self): num_rows = array_ops.placeholder(dtypes.int32) @@ -363,11 +363,11 @@ class LinearOperatorScaledIdentityTest( with self.test_session(): operator = linalg_lib.LinearOperatorScaledIdentity( num_rows, multiplier=[1., 2], assert_proper_shapes=True) - y = operator.apply(x) + y = operator.matmul(x) with self.assertRaisesOpError("Incompatible.*dimensions"): y.eval(feed_dict={num_rows: 2, x: rng.rand(3, 3)}) - def test_broadcast_apply_and_solve(self): + def test_broadcast_matmul_and_solve(self): # These cannot be done in the automated (base test class) tests since they # test shapes that tf.batch_matmul cannot handle. # In particular, tf.batch_matmul does not broadcast. @@ -383,11 +383,11 @@ class LinearOperatorScaledIdentityTest( # Batch matrix of zeros with the broadcast shape of x and operator. zeros = array_ops.zeros(shape=(2, 2, 3, 4), dtype=x.dtype) - # Test apply + # Test matmul expected = x * 2.2 + zeros - operator_apply = operator.apply(x) - self.assertAllEqual(operator_apply.get_shape(), expected.get_shape()) - self.assertAllClose(*sess.run([operator_apply, expected])) + operator_matmul = operator.matmul(x) + self.assertAllEqual(operator_matmul.get_shape(), expected.get_shape()) + self.assertAllClose(*sess.run([operator_matmul, expected])) # Test solve expected = x / 2.2 + zeros @@ -395,7 +395,7 @@ class LinearOperatorScaledIdentityTest( self.assertAllEqual(operator_solve.get_shape(), expected.get_shape()) self.assertAllClose(*sess.run([operator_solve, expected])) - def test_broadcast_apply_and_solve_scalar_scale_multiplier(self): + def test_broadcast_matmul_and_solve_scalar_scale_multiplier(self): # These cannot be done in the automated (base test class) tests since they # test shapes that tf.batch_matmul cannot handle. # In particular, tf.batch_matmul does not broadcast. @@ -409,11 +409,11 @@ class LinearOperatorScaledIdentityTest( operator = linalg_lib.LinearOperatorScaledIdentity( num_rows=3, multiplier=2.2) - # Test apply + # Test matmul expected = x * 2.2 - operator_apply = operator.apply(x) - self.assertAllEqual(operator_apply.get_shape(), expected.get_shape()) - self.assertAllClose(*sess.run([operator_apply, expected])) + operator_matmul = operator.matmul(x) + self.assertAllEqual(operator_matmul.get_shape(), expected.get_shape()) + self.assertAllClose(*sess.run([operator_matmul, expected])) # Test solve expected = x / 2.2 diff --git a/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_test.py b/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_test.py index d24388fce3..78a4822c17 100644 --- a/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_test.py +++ b/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_test.py @@ -23,6 +23,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops +from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops from tensorflow.python.platform import test @@ -54,12 +55,12 @@ class LinearOperatorShape(linalg.LinearOperator): def _shape_tensor(self): return constant_op.constant(self._stored_shape, dtype=dtypes.int32) - def _apply(self): + def _matmul(self): raise NotImplementedError("Not needed for this test.") -class LinearOperatorApplyOnly(linalg.LinearOperator): - """LinearOperator that simply wraps a [batch] matrix and implements apply.""" +class LinearOperatorMatmulSolve(linalg.LinearOperator): + """LinearOperator that wraps a [batch] matrix and implements matmul/solve.""" def __init__(self, matrix, @@ -68,8 +69,8 @@ class LinearOperatorApplyOnly(linalg.LinearOperator): is_positive_definite=None, is_square=None): self._matrix = ops.convert_to_tensor(matrix, name="matrix") - super(LinearOperatorApplyOnly, self).__init__( - dtype=matrix.dtype, + super(LinearOperatorMatmulSolve, self).__init__( + dtype=self._matrix.dtype, is_non_singular=is_non_singular, is_self_adjoint=is_self_adjoint, is_positive_definite=is_positive_definite, @@ -81,10 +82,16 @@ class LinearOperatorApplyOnly(linalg.LinearOperator): def _shape_tensor(self): return array_ops.shape(self._matrix) - def _apply(self, x, adjoint=False, adjoint_arg=False): + def _matmul(self, x, adjoint=False, adjoint_arg=False): + x = ops.convert_to_tensor(x, name="x") return math_ops.matmul( self._matrix, x, adjoint_a=adjoint, adjoint_b=adjoint_arg) + def _solve(self, rhs, adjoint=False, adjoint_arg=False): + rhs = ops.convert_to_tensor(rhs, name="rhs") + assert not adjoint_arg, "Not implemented for this test class." + return linalg_ops.matrix_solve(self._matrix, rhs, adjoint=adjoint) + class LinearOperatorTest(test.TestCase): @@ -122,7 +129,7 @@ class LinearOperatorTest(test.TestCase): def test_generic_to_dense_method_non_square_matrix_static(self): matrix = rng.randn(2, 3, 4) - operator = LinearOperatorApplyOnly(matrix) + operator = LinearOperatorMatmulSolve(matrix) with self.test_session(): operator_dense = operator.to_dense() self.assertAllEqual((2, 3, 4), operator_dense.get_shape()) @@ -131,12 +138,30 @@ class LinearOperatorTest(test.TestCase): def test_generic_to_dense_method_non_square_matrix_tensor(self): matrix = rng.randn(2, 3, 4) matrix_ph = array_ops.placeholder(dtypes.float64) - operator = LinearOperatorApplyOnly(matrix_ph) + operator = LinearOperatorMatmulSolve(matrix_ph) with self.test_session(): operator_dense = operator.to_dense() self.assertAllClose( matrix, operator_dense.eval(feed_dict={matrix_ph: matrix})) + def test_matvec(self): + matrix = [[1., 0], [0., 2.]] + operator = LinearOperatorMatmulSolve(matrix) + x = [1., 1.] + with self.test_session(): + y = operator.matvec(x) + self.assertAllEqual((2,), y.get_shape()) + self.assertAllClose([1., 2.], y.eval()) + + def test_solvevec(self): + matrix = [[1., 0], [0., 2.]] + operator = LinearOperatorMatmulSolve(matrix) + y = [1., 1.] + with self.test_session(): + x = operator.solvevec(y) + self.assertAllEqual((2,), x.get_shape()) + self.assertAllClose([1., 1 / 2.], x.eval()) + def test_is_square_set_to_true_for_square_static_shapes(self): operator = LinearOperatorShape(shape=(2, 4, 4)) self.assertTrue(operator.is_square) @@ -152,11 +177,11 @@ class LinearOperatorTest(test.TestCase): def test_is_square_set_inconsistent_with_other_hints_raises(self): with self.assertRaisesRegexp(ValueError, "is always square"): matrix = array_ops.placeholder(dtypes.float32) - LinearOperatorApplyOnly(matrix, is_non_singular=True, is_square=False) + LinearOperatorMatmulSolve(matrix, is_non_singular=True, is_square=False) with self.assertRaisesRegexp(ValueError, "is always square"): matrix = array_ops.placeholder(dtypes.float32) - LinearOperatorApplyOnly( + LinearOperatorMatmulSolve( matrix, is_positive_definite=True, is_square=False) def test_non_square_operators_raise_on_determinant_and_solve(self): @@ -170,16 +195,16 @@ class LinearOperatorTest(test.TestCase): with self.assertRaisesRegexp(ValueError, "is always square"): matrix = array_ops.placeholder(dtypes.float32) - LinearOperatorApplyOnly( + LinearOperatorMatmulSolve( matrix, is_positive_definite=True, is_square=False) def test_is_square_manual_set_works(self): matrix = array_ops.placeholder(dtypes.float32) # Default is None. - operator = LinearOperatorApplyOnly(matrix) + operator = LinearOperatorMatmulSolve(matrix) self.assertEqual(None, operator.is_square) # Set to True - operator = LinearOperatorApplyOnly(matrix, is_square=True) + operator = LinearOperatorMatmulSolve(matrix, is_square=True) self.assertTrue(operator.is_square) diff --git a/tensorflow/contrib/linalg/python/ops/linear_operator.py b/tensorflow/contrib/linalg/python/ops/linear_operator.py index 605ab1511d..6cdfa86189 100644 --- a/tensorflow/contrib/linalg/python/ops/linear_operator.py +++ b/tensorflow/contrib/linalg/python/ops/linear_operator.py @@ -51,8 +51,8 @@ class LinearOperator(object): To enable a public method, subclasses should implement the leading-underscore version of the method. The argument signature should be identical except for the omission of `name="..."`. For example, to enable - `apply(x, adjoint=False, name="apply")` a subclass should implement - `_apply(x, adjoint=False)`. + `matmul(x, adjoint=False, name="matmul")` a subclass should implement + `_matmul(x, adjoint=False)`. #### Performance contract @@ -72,7 +72,7 @@ class LinearOperator(object): An example is: - `x` is a batch matrix with compatible shape for `apply` if + `x` is a batch matrix with compatible shape for `matmul` if ``` operator.shape = [B1,...,Bb] + [M, N], b >= 0, @@ -109,7 +109,7 @@ class LinearOperator(object): x = ... Shape [2, 4, 5] Tensor - operator.apply(x) + operator.matmul(x) ==> Shape [2, 4, 5] Tensor ``` @@ -151,7 +151,7 @@ class LinearOperator(object): **Subclasses should copy-paste this `__init__` documentation.** Args: - dtype: The type of the this `LinearOperator`. Arguments to `apply` and + dtype: The type of the this `LinearOperator`. Arguments to `matmul` and `solve` will have to be this type. graph_parents: Python list of graph prerequisites of this `LinearOperator` Typically tensors that are passed during initialization. @@ -577,11 +577,25 @@ class LinearOperator(object): (self.dtype, arg.dtype, arg)) @abc.abstractmethod - def _apply(self, x, adjoint=False, adjoint_arg=False): - raise NotImplementedError("_apply is not implemented.") + def _matmul(self, x, adjoint=False, adjoint_arg=False): + raise NotImplementedError("_matmul is not implemented.") - def apply(self, x, adjoint=False, adjoint_arg=False, name="apply"): - """Transform `x` with left multiplication: `x --> Ax`. + def matmul(self, x, adjoint=False, adjoint_arg=False, name="matmul"): + """Transform [batch] matrix `x` with left multiplication: `x --> Ax`. + + ```python + # Make an operator acting like batch matrix A. Assume A.shape = [..., M, N] + operator = LinearOperator(...) + operator.shape = [..., M, N] + + X = ... # shape [..., N, R], batch matrix, R > 0. + + Y = operator.matmul(X) + Y.shape + ==> [..., M, R] + + Y[..., :, r] = sum_j A[..., :, j] X[j, r] + ``` Args: x: `Tensor` with compatible shape and same `dtype` as `self`. @@ -602,7 +616,46 @@ class LinearOperator(object): arg_dim = -1 if adjoint_arg else -2 self.shape[self_dim].assert_is_compatible_with(x.get_shape()[arg_dim]) - return self._apply(x, adjoint=adjoint, adjoint_arg=adjoint_arg) + return self._matmul(x, adjoint=adjoint, adjoint_arg=adjoint_arg) + + def _matvec(self, x, adjoint=False): + x_mat = array_ops.expand_dims(x, axis=-1) + y_mat = self.matmul(x_mat, adjoint=adjoint) + return array_ops.squeeze(y_mat, axis=-1) + + def matvec(self, x, adjoint=False, name="matvec"): + """Transform [batch] vector `x` with left multiplication: `x --> Ax`. + + ```python + # Make an operator acting like batch matric A. Assume A.shape = [..., M, N] + operator = LinearOperator(...) + + X = ... # shape [..., N], batch vector + + Y = operator.matvec(X) + Y.shape + ==> [..., M] + + Y[..., :] = sum_j A[..., :, j] X[..., j] + ``` + + Args: + x: `Tensor` with compatible shape and same `dtype` as `self`. + `x` is treated as a [batch] vector meaning for every set of leading + dimensions, the last dimension defines a vector. + See class docstring for definition of compatibility. + adjoint: Python `bool`. If `True`, left multiply by the adjoint: `A^H x`. + name: A name for this `Op. + + Returns: + A `Tensor` with shape `[..., M]` and same `dtype` as `self`. + """ + with self._name_scope(name, values=[x]): + x = ops.convert_to_tensor(x, name="x") + self._check_input_dtype(x) + self_dim = -2 if adjoint else -1 + self.shape[self_dim].assert_is_compatible_with(x.get_shape()[-1]) + return self._matvec(x, adjoint=adjoint) def _determinant(self): logging.warn( @@ -675,30 +728,33 @@ class LinearOperator(object): self._get_cached_dense_matrix(), rhs, adjoint=adjoint) def solve(self, rhs, adjoint=False, adjoint_arg=False, name="solve"): - """Solve `R` (batch) systems of equations with best effort: `A X = rhs`. + """Solve (exact or approx) `R` (batch) systems of equations: `A X = rhs`. - The solution may not be exact, and in this case it will be close in some - sense (see class docstring for details). + The returned `Tensor` will be close to an exact solution if `A` is well + conditioned. Otherwise closeness will vary. See class docstring for details. Examples: ```python - # Create an operator acting like a 10 x 2 x 2 matrix. + # Make an operator acting like batch matrix A. Assume A.shape = [..., M, N] operator = LinearOperator(...) - operator.shape # = 10 x 2 x 2 + operator.shape = [..., M, N] - # Solve one linear system (R = 1) for every member of the length 10 batch. - RHS = ... # shape 10 x 2 x 1 - X = operator.solve(RHS) # shape 10 x 2 x 1 + # Solve R > 0 linear systems for every member of the batch. + RHS = ... # shape [..., M, R] - # Solve five linear systems (R = 5) for every member of the length 10 batch. - RHS = ... # shape 10 x 2 x 5 X = operator.solve(RHS) - X[3, :, 2] # Solution to the linear system A[3, :, :] X = RHS[3, :, 2] + # X[..., :, r] is the solution to the r'th linear system + # sum_j A[..., :, j] X[..., j, r] = RHS[..., :, r] + + operator.matmul(X) + ==> RHS ``` Args: rhs: `Tensor` with same `dtype` as this operator and compatible shape. + `rhs` is treated like a [batch] matrix meaning for every set of leading + dimensions, the last two dimensions defines a matrix. See class docstring for definition of compatibility. adjoint: Python `bool`. If `True`, solve the system involving the adjoint of this `LinearOperator`: `A^H X = rhs`. @@ -730,6 +786,59 @@ class LinearOperator(object): return self._solve(rhs, adjoint=adjoint, adjoint_arg=adjoint_arg) + def _solvevec(self, rhs, adjoint=False): + """Default implementation of _solvevec.""" + rhs_mat = array_ops.expand_dims(rhs, axis=-1) + solution_mat = self.solve(rhs_mat, adjoint=adjoint) + return array_ops.squeeze(solution_mat, axis=-1) + + def solvevec(self, rhs, adjoint=False, name="solve"): + """Solve single equation with best effort: `A X = rhs`. + + The returned `Tensor` will be close to an exact solution if `A` is well + conditioned. Otherwise closeness will vary. See class docstring for details. + + Examples: + + ```python + # Make an operator acting like batch matrix A. Assume A.shape = [..., M, N] + operator = LinearOperator(...) + operator.shape = [..., M, N] + + # Solve one linear system for every member of the batch. + RHS = ... # shape [..., M] + + X = operator.solvevec(RHS) + # X is the solution to the linear system + # sum_j A[..., :, j] X[..., j] = RHS[..., :] + + operator.matvec(X) + ==> RHS + ``` + + Args: + rhs: `Tensor` with same `dtype` as this operator. + `rhs` is treated like a [batch] vector meaning for every set of leading + dimensions, the last dimension defines a vector. See class docstring + for definition of compatibility regarding batch dimensions. + adjoint: Python `bool`. If `True`, solve the system involving the adjoint + of this `LinearOperator`: `A^H X = rhs`. + name: A name scope to use for ops added by this method. + + Returns: + `Tensor` with shape `[...,N]` and same `dtype` as `rhs`. + + Raises: + NotImplementedError: If `self.is_non_singular` or `is_square` is False. + """ + with self._name_scope(name, values=[rhs]): + rhs = ops.convert_to_tensor(rhs, name="rhs") + self._check_input_dtype(rhs) + self_dim = -1 if adjoint else -2 + self.shape[self_dim].assert_is_compatible_with(rhs.get_shape()[-1]) + + return self._solvevec(rhs, adjoint=adjoint) + def _to_dense(self): """Generic and often inefficient implementation. Override often.""" logging.warn("Using (possibly slow) default implementation of to_dense." @@ -745,7 +854,7 @@ class LinearOperator(object): n = self.domain_dimension_tensor() eye = linalg_ops.eye(num_rows=n, batch_shape=batch_shape, dtype=self.dtype) - return self.apply(eye) + return self.matmul(eye) def to_dense(self, name="to_dense"): """Return a dense (batch) matrix representing this operator.""" diff --git a/tensorflow/contrib/linalg/python/ops/linear_operator_addition.py b/tensorflow/contrib/linalg/python/ops/linear_operator_addition.py index 7617e1b591..16c4c6e6d6 100644 --- a/tensorflow/contrib/linalg/python/ops/linear_operator_addition.py +++ b/tensorflow/contrib/linalg/python/ops/linear_operator_addition.py @@ -43,7 +43,7 @@ def add_operators(operators, Given operators `[A1, A2,...]`, this `Op` returns a possibly shorter list of operators `[B1, B2,...]` such that - ```sum_k Ak.apply(x) = sum_k Bk.apply(x).``` + ```sum_k Ak.matmul(x) = sum_k Bk.matmul(x).``` The operators `Bk` result by adding some of the `Ak`, as allowed by `addition_tiers`. diff --git a/tensorflow/contrib/linalg/python/ops/linear_operator_composition.py b/tensorflow/contrib/linalg/python/ops/linear_operator_composition.py index 550c630497..9dec621ab2 100644 --- a/tensorflow/contrib/linalg/python/ops/linear_operator_composition.py +++ b/tensorflow/contrib/linalg/python/ops/linear_operator_composition.py @@ -67,7 +67,7 @@ class LinearOperatorComposition(linear_operator.LinearOperator): ==> scalar Tensor x = ... Shape [2, 4] Tensor - operator.apply(x) + operator.matmul(x) ==> Shape [2, 4] Tensor # Create a [2, 3] batch of 4 x 5 linear operators. @@ -83,7 +83,7 @@ class LinearOperatorComposition(linear_operator.LinearOperator): # Create a shape [2, 3, 6, 2] vector. x = tf.random_normal(shape=[2, 3, 6, 2]) - operator.apply(x) + operator.matmul(x) ==> Shape [2, 3, 4, 2] Tensor ``` @@ -117,8 +117,8 @@ class LinearOperatorComposition(linear_operator.LinearOperator): r"""Initialize a `LinearOperatorComposition`. `LinearOperatorComposition` is initialized with a list of operators - `[op_1,...,op_J]`. For the `apply` method to be well defined, the - composition `op_i.apply(op_{i+1}(x))` must be defined. Other methods have + `[op_1,...,op_J]`. For the `matmul` method to be well defined, the + composition `op_i.matmul(op_{i+1}(x))` must be defined. Other methods have similar constraints. Args: @@ -228,19 +228,19 @@ class LinearOperatorComposition(linear_operator.LinearOperator): return array_ops.concat((batch_shape, matrix_shape), 0) - def _apply(self, x, adjoint=False, adjoint_arg=False): + def _matmul(self, x, adjoint=False, adjoint_arg=False): # If self.operators = [A, B], and not adjoint, then - # apply_order_list = [B, A]. - # As a result, we return A.apply(B.apply(x)) + # matmul_order_list = [B, A]. + # As a result, we return A.matmul(B.matmul(x)) if adjoint: - apply_order_list = self.operators + matmul_order_list = self.operators else: - apply_order_list = list(reversed(self.operators)) + matmul_order_list = list(reversed(self.operators)) - result = apply_order_list[0].apply( + result = matmul_order_list[0].matmul( x, adjoint=adjoint, adjoint_arg=adjoint_arg) - for operator in apply_order_list[1:]: - result = operator.apply(result, adjoint=adjoint) + for operator in matmul_order_list[1:]: + result = operator.matmul(result, adjoint=adjoint) return result def _determinant(self): diff --git a/tensorflow/contrib/linalg/python/ops/linear_operator_diag.py b/tensorflow/contrib/linalg/python/ops/linear_operator_diag.py index d81dea6514..56bc967706 100644 --- a/tensorflow/contrib/linalg/python/ops/linear_operator_diag.py +++ b/tensorflow/contrib/linalg/python/ops/linear_operator_diag.py @@ -56,7 +56,7 @@ class LinearOperatorDiag(linear_operator.LinearOperator): ==> scalar Tensor x = ... Shape [2, 4] Tensor - operator.apply(x) + operator.matmul(x) ==> Shape [2, 4] Tensor # Create a [2, 3] batch of 4 x 4 linear operators. @@ -68,13 +68,13 @@ class LinearOperatorDiag(linear_operator.LinearOperator): # operator.batch_shape = [2, 3]. y = tf.random_normal(shape=[2, 1, 4, 2]) x = operator.solve(y) - ==> operator.apply(x) = y + ==> operator.matmul(x) = y ``` #### Shape compatibility This operator acts on [batch] matrix with compatible shape. - `x` is a batch matrix with compatible shape for `apply` and `solve` if + `x` is a batch matrix with compatible shape for `matmul` and `solve` if ``` operator.shape = [B1,...,Bb] + [N, N], with b >= 0 @@ -87,7 +87,7 @@ class LinearOperatorDiag(linear_operator.LinearOperator): Suppose `operator` is a `LinearOperatorDiag` of shape `[N, N]`, and `x.shape = [N, R]`. Then - * `operator.apply(x)` involves `N * R` multiplications. + * `operator.matmul(x)` involves `N * R` multiplications. * `operator.solve(x)` involves `N` divisions and `N * R` multiplications. * `operator.determinant()` involves a size `N` `reduce_prod`. @@ -213,7 +213,7 @@ class LinearOperatorDiag(linear_operator.LinearOperator): "This diagonal operator contained non-zero imaginary values. " " Thus it was not self-adjoint.")) - def _apply(self, x, adjoint=False, adjoint_arg=False): + def _matmul(self, x, adjoint=False, adjoint_arg=False): diag_term = math_ops.conj(self._diag) if adjoint else self._diag x = linear_operator_util.matrix_adjoint(x) if adjoint_arg else x diag_mat = array_ops.expand_dims(diag_term, -1) diff --git a/tensorflow/contrib/linalg/python/ops/linear_operator_full_matrix.py b/tensorflow/contrib/linalg/python/ops/linear_operator_full_matrix.py index 0f245e609b..67889511cb 100644 --- a/tensorflow/contrib/linalg/python/ops/linear_operator_full_matrix.py +++ b/tensorflow/contrib/linalg/python/ops/linear_operator_full_matrix.py @@ -51,7 +51,7 @@ class LinearOperatorFullMatrix(linear_operator.LinearOperator): ==> scalar Tensor x = ... Shape [2, 4] Tensor - operator.apply(x) + operator.matmul(x) ==> Shape [2, 4] Tensor # Create a [2, 3] batch of 4 x 4 linear operators. @@ -62,7 +62,7 @@ class LinearOperatorFullMatrix(linear_operator.LinearOperator): #### Shape compatibility This operator acts on [batch] matrix with compatible shape. - `x` is a batch matrix with compatible shape for `apply` and `solve` if + `x` is a batch matrix with compatible shape for `matmul` and `solve` if ``` operator.shape = [B1,...,Bb] + [M, N], with b >= 0 @@ -81,7 +81,7 @@ class LinearOperatorFullMatrix(linear_operator.LinearOperator): In all cases, suppose `operator` is a `LinearOperatorFullMatrix` of shape `[M, N]`, and `x.shape = [N, R]`. Then - * `operator.apply(x)` is `O(M * N * R)`. + * `operator.matmul(x)` is `O(M * N * R)`. * If `M=N`, `operator.solve(x)` is `O(N^3 * R)`. * If `M=N`, `operator.determinant()` is `O(N^3)`. @@ -167,7 +167,7 @@ class LinearOperatorFullMatrix(linear_operator.LinearOperator): def _shape_tensor(self): return array_ops.shape(self._matrix) - def _apply(self, x, adjoint=False, adjoint_arg=False): + def _matmul(self, x, adjoint=False, adjoint_arg=False): return math_ops.matmul( self._matrix, x, adjoint_a=adjoint, adjoint_b=adjoint_arg) diff --git a/tensorflow/contrib/linalg/python/ops/linear_operator_identity.py b/tensorflow/contrib/linalg/python/ops/linear_operator_identity.py index d595442c70..acba1c7035 100644 --- a/tensorflow/contrib/linalg/python/ops/linear_operator_identity.py +++ b/tensorflow/contrib/linalg/python/ops/linear_operator_identity.py @@ -116,7 +116,7 @@ class LinearOperatorIdentity(BaseLinearOperatorIdentity): ==> 0. x = ... Shape [2, 4] Tensor - operator.apply(x) + operator.matmul(x) ==> Shape [2, 4] Tensor, same as x. y = tf.random_normal(shape=[3, 2, 4]) @@ -141,20 +141,20 @@ class LinearOperatorIdentity(BaseLinearOperatorIdentity): # to detect that no broadcast is necessary because both x and the operator # have statically defined shape. x = ... Shape [2, 2, 3] - operator.apply(x) + operator.matmul(x) ==> Shape [2, 2, 3] Tensor, same as x # Here the operator and x have different batch_shape, and are broadcast. # This requires a copy, since the output is different size than the input. x = ... Shape [1, 2, 3] - operator.apply(x) + operator.matmul(x) ==> Shape [2, 2, 3] Tensor, equal to [x, x] ``` ### Shape compatibility This operator acts on [batch] matrix with compatible shape. - `x` is a batch matrix with compatible shape for `apply` and `solve` if + `x` is a batch matrix with compatible shape for `matmul` and `solve` if ``` operator.shape = [B1,...,Bb] + [N, N], with b >= 0 @@ -166,14 +166,14 @@ class LinearOperatorIdentity(BaseLinearOperatorIdentity): If `batch_shape` initialization arg is `None`: - * `operator.apply(x)` is `O(1)` + * `operator.matmul(x)` is `O(1)` * `operator.solve(x)` is `O(1)` * `operator.determinant()` is `O(1)` If `batch_shape` initialization arg is provided, and static checks cannot rule out the need to broadcast: - * `operator.apply(x)` is `O(D1*...*Dd*N*R)` + * `operator.matmul(x)` is `O(D1*...*Dd*N*R)` * `operator.solve(x)` is `O(D1*...*Dd*N*R)` * `operator.determinant()` is `O(B1*...*Bb)` @@ -334,7 +334,7 @@ class LinearOperatorIdentity(BaseLinearOperatorIdentity): zeros = array_ops.zeros(shape=special_shape, dtype=self.dtype) return x + zeros - def _apply(self, x, adjoint=False, adjoint_arg=False): + def _matmul(self, x, adjoint=False, adjoint_arg=False): # Note that adjoint has no effect since this matrix is self-adjoint. x = linear_operator_util.matrix_adjoint(x) if adjoint_arg else x if self._assert_proper_shapes: @@ -350,7 +350,7 @@ class LinearOperatorIdentity(BaseLinearOperatorIdentity): return array_ops.zeros(shape=self.batch_shape_tensor(), dtype=self.dtype) def _solve(self, rhs, adjoint=False, adjoint_arg=False): - return self._apply(rhs, adjoint_arg=adjoint_arg) + return self._matmul(rhs, adjoint_arg=adjoint_arg) def _diag_part(self): return self._ones_diag() @@ -468,7 +468,7 @@ class LinearOperatorScaledIdentity(BaseLinearOperatorIdentity): ==> 2 * Log[3] x = ... Shape [2, 4] Tensor - operator.apply(x) + operator.matmul(x) ==> 3 * x y = tf.random_normal(shape=[3, 2, 4]) @@ -486,19 +486,19 @@ class LinearOperatorScaledIdentity(BaseLinearOperatorIdentity): [0., 5.]]] x = ... Shape [2, 2, 3] - operator.apply(x) + operator.matmul(x) ==> 5 * x # Here the operator and x have different batch_shape, and are broadcast. x = ... Shape [1, 2, 3] - operator.apply(x) + operator.matmul(x) ==> 5 * x ``` ### Shape compatibility This operator acts on [batch] matrix with compatible shape. - `x` is a batch matrix with compatible shape for `apply` and `solve` if + `x` is a batch matrix with compatible shape for `matmul` and `solve` if ``` operator.shape = [B1,...,Bb] + [N, N], with b >= 0 @@ -508,7 +508,7 @@ class LinearOperatorScaledIdentity(BaseLinearOperatorIdentity): ### Performance - * `operator.apply(x)` is `O(D1*...*Dd*N*R)` + * `operator.matmul(x)` is `O(D1*...*Dd*N*R)` * `operator.solve(x)` is `O(D1*...*Dd*N*R)` * `operator.determinant()` is `O(D1*...*Dd)` @@ -628,7 +628,7 @@ class LinearOperatorScaledIdentity(BaseLinearOperatorIdentity): imag_multiplier, message="LinearOperator was not self-adjoint") - def _apply(self, x, adjoint=False, adjoint_arg=False): + def _matmul(self, x, adjoint=False, adjoint_arg=False): x = linear_operator_util.matrix_adjoint(x) if adjoint_arg else x if adjoint: matrix = self._multiplier_matrix_conj diff --git a/tensorflow/contrib/linalg/python/ops/linear_operator_test_util.py b/tensorflow/contrib/linalg/python/ops/linear_operator_test_util.py index c8bc62eeef..b2d7b10157 100644 --- a/tensorflow/contrib/linalg/python/ops/linear_operator_test_util.py +++ b/tensorflow/contrib/linalg/python/ops/linear_operator_test_util.py @@ -116,7 +116,7 @@ class LinearOperatorDerivedClassTest(test.TestCase): @abc.abstractmethod def _make_x(self, operator, adjoint): - """Make an 'x' appropriate for calling operator.apply(x). + """Make an 'x' appropriate for calling operator.matmul(x). Args: operator: A `LinearOperator` @@ -208,8 +208,8 @@ class LinearOperatorDerivedClassTest(test.TestCase): feed_dict=feed_dict) self.assertAC(op_log_abs_det_v, mat_log_abs_det_v) - def test_apply(self): - self._skip_if_tests_to_skip_contains("apply") + def test_matmul(self): + self._skip_if_tests_to_skip_contains("matmul") for use_placeholder in False, True: for shape in self._shapes_to_test: for dtype in self._dtypes_to_test: @@ -222,18 +222,18 @@ class LinearOperatorDerivedClassTest(test.TestCase): x = self._make_x(operator, adjoint=adjoint) # If adjoint_arg, compute A X^H^H = A X. if adjoint_arg: - op_apply = operator.apply( + op_matmul = operator.matmul( linear_operator_util.matrix_adjoint(x), adjoint=adjoint, adjoint_arg=adjoint_arg) else: - op_apply = operator.apply(x, adjoint=adjoint) - mat_apply = math_ops.matmul(mat, x, adjoint_a=adjoint) + op_matmul = operator.matmul(x, adjoint=adjoint) + mat_matmul = math_ops.matmul(mat, x, adjoint_a=adjoint) if not use_placeholder: self.assertAllEqual( - op_apply.get_shape(), mat_apply.get_shape()) - op_apply_v, mat_apply_v = sess.run([op_apply, mat_apply], - feed_dict=feed_dict) - self.assertAC(op_apply_v, mat_apply_v) + op_matmul.get_shape(), mat_matmul.get_shape()) + op_matmul_v, mat_matmul_v = sess.run( + [op_matmul, mat_matmul], feed_dict=feed_dict) + self.assertAC(op_matmul_v, mat_matmul_v) def test_solve(self): self._skip_if_tests_to_skip_contains("solve") @@ -376,7 +376,7 @@ class NonSquareLinearOperatorDerivedClassTest(LinearOperatorDerivedClassTest): "_make_rhs not implemented because we don't test solve") def _make_x(self, operator, adjoint): - # Return the number of systems for the argument 'x' for .apply(x) + # Return the number of systems for the argument 'x' for .matmul(x) r = self._get_num_systems(operator) # If operator.shape = [B1,...,Bb, M, N] this returns a random matrix of # shape [B1,...,Bb, N, R], R = 1 or 2. diff --git a/tensorflow/contrib/linalg/python/ops/linear_operator_tril.py b/tensorflow/contrib/linalg/python/ops/linear_operator_tril.py index 6d4033c2a3..8a152a9b47 100644 --- a/tensorflow/contrib/linalg/python/ops/linear_operator_tril.py +++ b/tensorflow/contrib/linalg/python/ops/linear_operator_tril.py @@ -57,7 +57,7 @@ class LinearOperatorTriL(linear_operator.LinearOperator): ==> scalar Tensor x = ... Shape [2, 4] Tensor - operator.apply(x) + operator.matmul(x) ==> Shape [2, 4] Tensor # Create a [2, 3] batch of 4 x 4 linear operators. @@ -68,7 +68,7 @@ class LinearOperatorTriL(linear_operator.LinearOperator): #### Shape compatibility This operator acts on [batch] matrix with compatible shape. - `x` is a batch matrix with compatible shape for `apply` and `solve` if + `x` is a batch matrix with compatible shape for `matmul` and `solve` if ``` operator.shape = [B1,...,Bb] + [N, N], with b >= 0 @@ -80,7 +80,7 @@ class LinearOperatorTriL(linear_operator.LinearOperator): Suppose `operator` is a `LinearOperatorTriL` of shape `[N, N]`, and `x.shape = [N, R]`. Then - * `operator.apply(x)` involves `N^2 * R` multiplications. + * `operator.matmul(x)` involves `N^2 * R` multiplications. * `operator.solve(x)` involves `N * R` size `N` back-substitutions. * `operator.determinant()` involves a size `N` `reduce_prod`. @@ -182,7 +182,7 @@ class LinearOperatorTriL(linear_operator.LinearOperator): self._diag, message="Singular operator: Diagonal contained zero values.") - def _apply(self, x, adjoint=False, adjoint_arg=False): + def _matmul(self, x, adjoint=False, adjoint_arg=False): return math_ops.matmul( self._tril, x, adjoint_a=adjoint, adjoint_b=adjoint_arg) diff --git a/tensorflow/contrib/linalg/python/ops/linear_operator_udvh_update.py b/tensorflow/contrib/linalg/python/ops/linear_operator_udvh_update.py index 4ca77ab147..546d899e74 100644 --- a/tensorflow/contrib/linalg/python/ops/linear_operator_udvh_update.py +++ b/tensorflow/contrib/linalg/python/ops/linear_operator_udvh_update.py @@ -74,18 +74,18 @@ class LinearOperatorUDVHUpdate(linear_operator.LinearOperator): operator.shape ==> [3, 3] - operator.log_determinant() + operator.log_abs_determinant() ==> scalar Tensor x = ... Shape [3, 4] Tensor - operator.apply(x) + operator.matmul(x) ==> Shape [3, 4] Tensor ``` ### Shape compatibility This operator acts on [batch] matrix with compatible shape. - `x` is a batch matrix with compatible shape for `apply` and `solve` if + `x` is a batch matrix with compatible shape for `matmul` and `solve` if ``` operator.shape = [B1,...,Bb] + [M, N], with b >= 0 @@ -95,15 +95,15 @@ class LinearOperatorUDVHUpdate(linear_operator.LinearOperator): ### Performance Suppose `operator` is a `LinearOperatorUDVHUpdate` of shape `[M, N]`, - made from a rank `K` update of `base_operator` which performs `.apply(x)` on - `x` having `x.shape = [N, R]` with `O(L_apply*N*R)` complexity (and similarly + made from a rank `K` update of `base_operator` which performs `.matmul(x)` on + `x` having `x.shape = [N, R]` with `O(L_matmul*N*R)` complexity (and similarly for `solve`, `determinant`. Then, if `x.shape = [N, R]`, - * `operator.apply(x)` is `O(L_apply*N*R + K*N*R)` + * `operator.matmul(x)` is `O(L_matmul*N*R + K*N*R)` and if `M = N`, - * `operator.solve(x)` is `O(L_apply*N*R + N*K*R + K^2*R + K^3)` + * `operator.solve(x)` is `O(L_matmul*N*R + N*K*R + K^2*R + K^3)` * `operator.determinant()` is `O(L_determinant + L_solve*N*K + K^2*N + K^3)` If instead `operator` and `x` have shape `[B1,...,Bb, M, N]` and @@ -348,22 +348,22 @@ class LinearOperatorUDVHUpdate(linear_operator.LinearOperator): return array_ops.concat( [batch_shape, self.base_operator.shape_tensor()[-2:]], axis=0) - def _apply(self, x, adjoint=False, adjoint_arg=False): + def _matmul(self, x, adjoint=False, adjoint_arg=False): u = self.u v = self.v l = self.base_operator d = self.diag_operator - leading_term = l.apply(x, adjoint=adjoint, adjoint_arg=adjoint_arg) + leading_term = l.matmul(x, adjoint=adjoint, adjoint_arg=adjoint_arg) if adjoint: uh_x = math_ops.matmul(u, x, adjoint_a=True, adjoint_b=adjoint_arg) - d_uh_x = d.apply(uh_x, adjoint=adjoint) + d_uh_x = d.matmul(uh_x, adjoint=adjoint) v_d_uh_x = math_ops.matmul(v, d_uh_x) return leading_term + v_d_uh_x else: vh_x = math_ops.matmul(v, x, adjoint_a=True, adjoint_b=adjoint_arg) - d_vh_x = d.apply(vh_x, adjoint=adjoint) + d_vh_x = d.matmul(vh_x, adjoint=adjoint) u_d_vh_x = math_ops.matmul(u, d_vh_x) return leading_term + u_d_vh_x diff --git a/tensorflow/contrib/linalg/python/ops/linear_operator_util.py b/tensorflow/contrib/linalg/python/ops/linear_operator_util.py index 9f8cb23169..2659bd32e9 100644 --- a/tensorflow/contrib/linalg/python/ops/linear_operator_util.py +++ b/tensorflow/contrib/linalg/python/ops/linear_operator_util.py @@ -69,10 +69,10 @@ def assert_zero_imag_part(x, message=None, name="assert_zero_imag_part"): def assert_compatible_matrix_dimensions(operator, x): - """Assert that an argument to solve/apply has proper domain dimension. + """Assert that an argument to solve/matmul has proper domain dimension. If `operator.shape[-2:] = [M, N]`, and `x.shape[-2:] = [Q, R]`, then - `operator.apply(x)` is defined only if `N = Q`. This `Op` returns an + `operator.matmul(x)` is defined only if `N = Q`. This `Op` returns an `Assert` that "fires" if this is not the case. Static checks are already done by the base class `LinearOperator`. -- GitLab From 968ac01d7c12e192f7f311775b7b58f13075dc46 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 16 May 2017 11:40:03 -0700 Subject: [PATCH 679/697] Adds a new rule in tensorflow/core/BUILD to enable no-rtti build for selective registration android library. PiperOrigin-RevId: 156208734 --- tensorflow/core/BUILD | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index a0d56df4aa..6056d50587 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -1019,6 +1019,27 @@ cc_library( alwayslink = 1, ) +# Android library for use with the SELECTIVE_REGISTRATION feature with +# no proto_rtti. +cc_library( + name = "android_tensorflow_lib_selective_registration_nortti", + srcs = if_android(["//tensorflow/core:android_srcs"]), + copts = tf_copts() + tf_opts_nortti_if_android() + [ + "-Os", + "-DSUPPORT_SELECTIVE_REGISTRATION", + ], + tags = [ + "manual", + "notap", + ], + visibility = ["//visibility:public"], + deps = [ + ":protos_cc", + "//third_party/eigen3", + ], + alwayslink = 1, +) + filegroup( name = "android_op_registrations_and_gradients", srcs = glob( -- GitLab From 194b1644cba446d36c6d192771b11245847db7cc Mon Sep 17 00:00:00 2001 From: Yao Zhang Date: Tue, 16 May 2017 11:55:45 -0700 Subject: [PATCH 680/697] Add topological sort to grappler util and to the autoparallel optimizer. The graph importer requires graph to be topologically sorted for shape inference. #9906 PiperOrigin-RevId: 156211060 --- tensorflow/core/grappler/optimizers/BUILD | 1 + .../grappler/optimizers/meta_optimizer.cc | 2 + tensorflow/core/grappler/utils/BUILD | 24 +++++ .../core/grappler/utils/topological_sort.cc | 64 +++++++++++++ .../core/grappler/utils/topological_sort.h | 30 ++++++ .../grappler/utils/topological_sort_test.cc | 94 +++++++++++++++++++ 6 files changed, 215 insertions(+) create mode 100644 tensorflow/core/grappler/utils/topological_sort.cc create mode 100644 tensorflow/core/grappler/utils/topological_sort.h create mode 100644 tensorflow/core/grappler/utils/topological_sort_test.cc diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD index 07c7ea1fc1..f88b995c89 100644 --- a/tensorflow/core/grappler/optimizers/BUILD +++ b/tensorflow/core/grappler/optimizers/BUILD @@ -264,5 +264,6 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler/utils:topological_sort", ], ) diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc index 23adaab8ed..8bb7800df4 100644 --- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/core/grappler/optimizers/layout_optimizer.h" #include "tensorflow/core/grappler/optimizers/memory_optimizer.h" #include "tensorflow/core/grappler/optimizers/model_pruner.h" +#include "tensorflow/core/grappler/utils/topological_sort.h" #include "tensorflow/core/lib/core/status.h" namespace tensorflow { @@ -99,6 +100,7 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, optimizer->Optimize(cluster, optimized_item, optimized_graph)); } } + TopologicalSort(optimized_graph); // Copy the graph version. *optimized_graph->mutable_versions() = item.graph.versions(); diff --git a/tensorflow/core/grappler/utils/BUILD b/tensorflow/core/grappler/utils/BUILD index e1a3c574a6..e1db1a8cd2 100644 --- a/tensorflow/core/grappler/utils/BUILD +++ b/tensorflow/core/grappler/utils/BUILD @@ -38,3 +38,27 @@ cc_test( "//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder", ], ) + +cc_library( + name = "topological_sort", + srcs = ["topological_sort.cc"], + hdrs = ["topological_sort.h"], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/grappler:utils", + ], +) + +cc_test( + name = "topological_sort_test", + srcs = ["topological_sort_test.cc"], + deps = [ + ":topological_sort", + "//tensorflow/core:lib_proto_parsing", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) diff --git a/tensorflow/core/grappler/utils/topological_sort.cc b/tensorflow/core/grappler/utils/topological_sort.cc new file mode 100644 index 0000000000..131756fc5c --- /dev/null +++ b/tensorflow/core/grappler/utils/topological_sort.cc @@ -0,0 +1,64 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/grappler/utils/topological_sort.h" +#include +#include +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/grappler/utils.h" + +namespace tensorflow { +namespace grappler { + +// Kahn's algorithm is implemented. +// For details, see https://en.wikipedia.org/wiki/Topological_sorting +void TopologicalSort(GraphDef* graph) { + NodeMap node_map(graph); + std::deque ready_nodes; + std::unordered_map ready_inputs; + for (const NodeDef& node : graph->node()) { + if (node.input_size() == 0) { + ready_nodes.push_back(&node); + } + if (node.op() == "Merge") { + ready_inputs[&node] = 0; + for (const auto& input : node.input()) { + if (node_map.GetNode(input)->op() == "NextIteration") { + ready_inputs[&node]++; + } + } + } else { + ready_inputs[&node] = 0; + } + } + GraphDef sorted_graph; + while (!ready_nodes.empty()) { + auto ready_node = ready_nodes.front(); + *sorted_graph.add_node() = *ready_node; + for (const auto& fanout : node_map.GetOutputs(ready_node->name())) { + ready_inputs[fanout]++; + if (ready_inputs[fanout] == fanout->input_size()) { + ready_nodes.push_back(fanout); + } + } + ready_nodes.pop_front(); + } + if (sorted_graph.node_size() == graph->node_size()) { + *graph = sorted_graph; + } +} + +} // namespace grappler +} // namespace tensorflow diff --git a/tensorflow/core/grappler/utils/topological_sort.h b/tensorflow/core/grappler/utils/topological_sort.h new file mode 100644 index 0000000000..d4d8034ef5 --- /dev/null +++ b/tensorflow/core/grappler/utils/topological_sort.h @@ -0,0 +1,30 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_CORE_GRAPPLER_UTILS_TOPOLOGICAL_SORT_H_ +#define THIRD_PARTY_TENSORFLOW_CORE_GRAPPLER_UTILS_TOPOLOGICAL_SORT_H_ + +#include "tensorflow/core/framework/graph.pb.h" + +namespace tensorflow { +namespace grappler { + +// Sort a graph in topological order. +void TopologicalSort(GraphDef* graph); + +} // namespace grappler +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_CORE_GRAPPLER_UTILS_TOPOLOGICAL_SORT_H_ diff --git a/tensorflow/core/grappler/utils/topological_sort_test.cc b/tensorflow/core/grappler/utils/topological_sort_test.cc new file mode 100644 index 0000000000..55f66b2734 --- /dev/null +++ b/tensorflow/core/grappler/utils/topological_sort_test.cc @@ -0,0 +1,94 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/grappler/utils/topological_sort.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace grappler { +namespace { + +class TopologicalSortTest : public ::testing::Test { + protected: + static NodeDef CreateNode(const string& name, + const std::vector& inputs) { + return CreateNode(name, "", inputs); + } + static NodeDef CreateNode(const string& name, const string& op, + const std::vector& inputs) { + NodeDef node; + node.set_name(name); + if (!op.empty()) { + node.set_op(op); + } + for (const string& input : inputs) { + node.add_input(input); + } + return node; + } +}; + +TEST_F(TopologicalSortTest, NoLoop) { + GraphDef graph; + *graph.add_node() = CreateNode("2", {"5"}); + *graph.add_node() = CreateNode("0", {"5", "4"}); + *graph.add_node() = CreateNode("1", {"4", "3"}); + *graph.add_node() = CreateNode("3", {"2"}); + *graph.add_node() = CreateNode("5", {}); + *graph.add_node() = CreateNode("4", {}); + + TopologicalSort(&graph); + std::vector order = {"5", "4", "2", "0", "3", "1"}; + for (int i = 0; i < order.size(); i++) { + EXPECT_EQ(graph.node(i).name(), order[i]); + } +} + +TEST_F(TopologicalSortTest, WithLoop) { + GraphDef graph; + // Create a loop + *graph.add_node() = CreateNode("2", "Merge", {"1", "5"}); + *graph.add_node() = CreateNode("3", "Switch", {"2"}); + *graph.add_node() = CreateNode("4", "Identity", {"3"}); + *graph.add_node() = CreateNode("5", "NextIteration", {"4"}); + *graph.add_node() = CreateNode("1", {}); + + TopologicalSort(&graph); + std::vector order = {"1", "2", "3", "4", "5"}; + for (int i = 0; i < order.size(); i++) { + EXPECT_EQ(graph.node(i).name(), order[i]); + } +} + +TEST_F(TopologicalSortTest, WithIllegalLoop) { + GraphDef graph; + // A loop without Merge and NextIteration is illegal and the original node + // order and graph will be preserved. + *graph.add_node() = CreateNode("2", {"1", "3"}); + *graph.add_node() = CreateNode("3", {"2"}); + *graph.add_node() = CreateNode("1", {}); + + TopologicalSort(&graph); + std::vector order = {"2", "3", "1"}; + for (int i = 0; i < order.size(); i++) { + EXPECT_EQ(graph.node(i).name(), order[i]); + } +} + +} // namespace +} // namespace grappler +} // namespace tensorflow -- GitLab From d0b47889fb3a776dd6cf0a3c8360be9f84e2cbbb Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 16 May 2017 12:18:06 -0700 Subject: [PATCH 681/697] Remove the paper-item/all-imports.html import within tf_graph_info_d3v4/tf-node-info.html import files. PiperOrigin-RevId: 156214018 --- .../components/tf_graph_info_d3v4/tf-node-info.html | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tensorflow/tensorboard/components/tf_graph_info_d3v4/tf-node-info.html b/tensorflow/tensorboard/components/tf_graph_info_d3v4/tf-node-info.html index 0715777370..f1455acaee 100644 --- a/tensorflow/tensorboard/components/tf_graph_info_d3v4/tf-node-info.html +++ b/tensorflow/tensorboard/components/tf_graph_info_d3v4/tf-node-info.html @@ -19,7 +19,8 @@ limitations under the License. - + + -- GitLab From 3c5e944a8fa14ade68505ca260a65c830c28e707 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A0=20Arrufat?= Date: Tue, 16 May 2017 21:25:24 +0200 Subject: [PATCH 682/697] add pkg-config generation script (#9784) * add pkg-config generation script * add suggestions from asimshankar to generate-pc.sh --- tensorflow/c/generate-pc.sh | 49 +++++++++++++++++++++++++++++++++++++ 1 file changed, 49 insertions(+) create mode 100755 tensorflow/c/generate-pc.sh diff --git a/tensorflow/c/generate-pc.sh b/tensorflow/c/generate-pc.sh new file mode 100755 index 0000000000..ea2eed011c --- /dev/null +++ b/tensorflow/c/generate-pc.sh @@ -0,0 +1,49 @@ +#!/usr/bin/env bash + +TF_PREFIX='/usr/local' + +usage() { + echo "Usage: $0 OPTIONS" + echo -e "-p, --prefix\tset installation prefix (default: /usr/local)" + echo -e "-v, --version\tset TensorFlow version" + echo -e "-h, --help\tdisplay this message" +} + +# read the options +ARGS=`getopt -o p:v:h --long prefix:,version:,help -n $0 -- "$@"` +eval set -- "$ARGS" + +# extract options and their arguments into variables. +while true ; do + case "$1" in + -h|--help) usage ; exit ;; + -p|--prefix) + case "$2" in + "") shift 2 ;; + *) TF_PREFIX=$2 ; shift 2 ;; + esac ;; + -v|--version) + case "$2" in + "") shift 2 ;; + *) TF_VERSION=$2 ; shift 2 ;; + esac ;; + --) shift ; echo "Try '$0 --help' for more information."; exit 1 ;; + *) echo "Internal error! Try '$0 --help' for more information." ; exit 1 ;; + esac +done + +echo "Generating pkgconfig file for TensorFlow $TF_VERSION in $TF_PREFIX" + +cat << EOF > tensorflow.pc +prefix=${TF_PREFIX} +exec_prefix=\${prefix} +libdir=\${exec_prefix}/lib +includedir=\${prefix}/include + +Name: TensorFlow +Version: ${TF_VERSION} +Description: Library for computation using data flow graphs for scalable machine learning +Requires: +Libs: -L\${libdir} -ltensorflow +Cflags: -I\${includedir} +EOF -- GitLab From 97cf8962f6e304f8dfb1f6737720066a43e93172 Mon Sep 17 00:00:00 2001 From: krivard Date: Tue, 16 May 2017 15:27:48 -0400 Subject: [PATCH 683/697] Replaced deprecated op_scope call with name_scope and correct argument order (#9941) --- tensorflow/python/ops/sparse_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/ops/sparse_ops.py b/tensorflow/python/ops/sparse_ops.py index 23afb13d4f..9286114277 100644 --- a/tensorflow/python/ops/sparse_ops.py +++ b/tensorflow/python/ops/sparse_ops.py @@ -1766,7 +1766,7 @@ def sparse_transpose(sp_input, perm=None, name=None): Raises: TypeError: If `sp_input` is not a `SparseTensor`. """ - with ops.op_scope([sp_input], name, "SparseTranspose") as name: + with ops.name_scope(name, "SparseTranspose", [sp_input]) as name: if perm is None: rank = array_ops.rank(sp_input) perm = (rank - 1) - math_ops.range(0, rank, 1) -- GitLab From da0d883cba8c3b8103e17432bf8398a56a719c40 Mon Sep 17 00:00:00 2001 From: Chris Hoyean Song Date: Wed, 17 May 2017 04:35:46 +0900 Subject: [PATCH 684/697] Add initializer to GRUCell, MultiRNNCell #9600 (#9700) * Add initializer to GRUCell, MultiRNNCell #9600 * change initializer to kernel_initializer add bias_initializer remove initializer from MultiRNNCell --- .../rnn/python/ops/core_rnn_cell_impl.py | 28 ++++++++++++++----- 1 file changed, 21 insertions(+), 7 deletions(-) diff --git a/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py b/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py index eba2c0d2ac..f62fba2fed 100644 --- a/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py +++ b/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py @@ -83,12 +83,15 @@ class BasicRNNCell(RNNCell): class GRUCell(RNNCell): """Gated Recurrent Unit cell (cf. http://arxiv.org/abs/1406.1078).""" - def __init__(self, num_units, input_size=None, activation=tanh, reuse=None): + def __init__(self, num_units, input_size=None, activation=tanh, reuse=None, + kernel_initializer=None, bias_initializer=None): super(GRUCell, self).__init__(_reuse=reuse) if input_size is not None: logging.warn("%s: The input_size parameter is deprecated.", self) self._num_units = num_units self._activation = activation + self._kernel_initializer = kernel_initializer + self._bias_initializer = bias_initializer @property def state_size(self): @@ -102,10 +105,16 @@ class GRUCell(RNNCell): """Gated recurrent unit (GRU) with nunits cells.""" with vs.variable_scope("gates"): # Reset gate and update gate. # We start with bias of 1.0 to not reset and not update. - value = sigmoid(_linear([inputs, state], 2 * self._num_units, True, 1.0)) + bias_ones = self._bias_initializer + if self._bias_initializer is None: + dtype = [a.dtype for a in [inputs, state]][0] + bias_ones = init_ops.constant_initializer(1.0, dtype=dtype) + value = sigmoid(_linear([inputs, state], 2 * self._num_units, True, + bias_ones, self._kernel_initializer)) r, u = array_ops.split(value=value, num_or_size_splits=2, axis=1) with vs.variable_scope("candidate"): - c = self._activation(_linear([inputs, r * state], self._num_units, True)) + c = self._activation(_linear([inputs, r * state], self._num_units, True, + self._bias_initializer, self._kernel_initializer)) new_h = u * state + (1 - u) * c return new_h, new_h @@ -963,14 +972,16 @@ class _SlimRNNCell(RNNCell): return output, state -def _linear(args, output_size, bias, bias_start=0.0): +def _linear(args, output_size, bias, bias_initializer=None, + kernel_initializer=None): """Linear map: sum_i(args[i] * W[i]), where W[i] is a variable. Args: args: a 2D Tensor or a list of 2D, batch x n, Tensors. output_size: int, second dimension of W[i]. bias: boolean, whether to add a bias term or not. - bias_start: starting value to initialize the bias; 0 by default. + bias_initializer: starting value to initialize the bias; None by default. + kernel_initializer: starting value to initialize the weight; None by default. Returns: A 2D Tensor with shape [batch x output_size] equal to @@ -1002,7 +1013,8 @@ def _linear(args, output_size, bias, bias_start=0.0): scope = vs.get_variable_scope() with vs.variable_scope(scope) as outer_scope: weights = vs.get_variable( - _WEIGHTS_VARIABLE_NAME, [total_arg_size, output_size], dtype=dtype) + _WEIGHTS_VARIABLE_NAME, [total_arg_size, output_size], dtype=dtype, + initializer=kernel_initializer) if len(args) == 1: res = math_ops.matmul(args[0], weights) else: @@ -1011,8 +1023,10 @@ def _linear(args, output_size, bias, bias_start=0.0): return res with vs.variable_scope(outer_scope) as inner_scope: inner_scope.set_partitioner(None) + if bias_initializer is None: + bias_initializer = init_ops.constant_initializer(0.0, dtype=dtype) biases = vs.get_variable( _BIAS_VARIABLE_NAME, [output_size], dtype=dtype, - initializer=init_ops.constant_initializer(bias_start, dtype=dtype)) + initializer=bias_initializer) return nn_ops.bias_add(res, biases) -- GitLab From fe41d05e7c8343ed53fc788d6c312792b390f679 Mon Sep 17 00:00:00 2001 From: Asim Shankar Date: Tue, 16 May 2017 13:08:04 -0700 Subject: [PATCH 685/697] Go: Keep errors in Go. Helps with #9931 - graph.go: Return a nil Operation if the operation creation failed. This ensures that accidental usage of the nil operation results in a panic/stacktrace purely in Go (nil pointer dereference) instead of a SIGSEGV in the underlying C API. - operation.go: Attempt to help in the developer find the root cause of the nil pointer dereference with a chattier panic message. PiperOrigin-RevId: 156220056 --- tensorflow/go/graph.go | 8 ++++---- tensorflow/go/op/op_test.go | 25 +++++++++++++++++++++++++ tensorflow/go/operation.go | 5 +++++ 3 files changed, 34 insertions(+), 4 deletions(-) diff --git a/tensorflow/go/graph.go b/tensorflow/go/graph.go index e65619e80b..46c600eab1 100644 --- a/tensorflow/go/graph.go +++ b/tensorflow/go/graph.go @@ -185,11 +185,11 @@ func (g *Graph) AddOperation(args OpSpec) (*Operation, error) { return nil, fmt.Errorf("%v (memory will be leaked)", err) } } - op := &Operation{ - c: C.TF_FinishOperation(cdesc, status.c), - g: g, + c := C.TF_FinishOperation(cdesc, status.c) + if err := status.Err(); err != nil { + return nil, err } - return op, status.Err() + return &Operation{c, g}, nil } func setAttr(cdesc *C.TF_OperationDescription, status *status, name string, value interface{}) error { diff --git a/tensorflow/go/op/op_test.go b/tensorflow/go/op/op_test.go index 65877dca96..2451ba3606 100644 --- a/tensorflow/go/op/op_test.go +++ b/tensorflow/go/op/op_test.go @@ -19,6 +19,7 @@ limitations under the License. package op import ( + "strings" "testing" tf "github.com/tensorflow/tensorflow/tensorflow/go" @@ -33,3 +34,27 @@ func TestPlaceholder(t *testing.T) { t.Fatal(err) } } + +func TestAddOperationFailure(t *testing.T) { + // Inspired from https://github.com/tensorflow/tensorflow/issues/9931 + s := NewScope() + + resize := ResizeArea(s, Placeholder(s, tf.Float), Const(s, []int64{80, 80})) + if err := s.Err(); err == nil { + t.Fatal("ResizeArea expects an int32 Tensor for size, should fail when an int64 is provided") + } + // And any use of resize should panic with an error message more informative than SIGSEGV + defer func() { + r := recover() + if r == nil { + return + } + s, ok := r.(string) + if ok && strings.Contains(s, "see Scope.Err() for details") { + return + } + t.Errorf("Expected panic string to Scope.Err(), found %T: %q", r, r) + }() + _ = resize.Shape() + t.Errorf("resize.Shape() should have paniced since the underlying Operation was not created") +} diff --git a/tensorflow/go/operation.go b/tensorflow/go/operation.go index e8f67c4f73..8fcad61f4c 100644 --- a/tensorflow/go/operation.go +++ b/tensorflow/go/operation.go @@ -113,6 +113,11 @@ func (p Output) Shape() Shape { } func (p Output) c() C.TF_Output { + if p.Op == nil { + // Attempt to provide a more useful panic message than "nil + // pointer dereference". + panic("nil-Operation. If the Output was created with a Scope object, see Scope.Err() for details.") + } return C.TF_Output{oper: p.Op.c, index: C.int(p.Index)} } -- GitLab From 70744fb43b75743eeff0b09b862128281a7f3494 Mon Sep 17 00:00:00 2001 From: Asim Shankar Date: Tue, 16 May 2017 13:21:25 -0700 Subject: [PATCH 686/697] Go: Some cleanup possible now that TensorFlow 1.1 has been released. PiperOrigin-RevId: 156221562 --- tensorflow/go/lib.go | 10 ---------- tensorflow/go/session.cpp | 26 -------------------------- tensorflow/go/session.go | 2 +- 3 files changed, 1 insertion(+), 37 deletions(-) delete mode 100644 tensorflow/go/session.cpp diff --git a/tensorflow/go/lib.go b/tensorflow/go/lib.go index 551cfa0b01..2800eded60 100644 --- a/tensorflow/go/lib.go +++ b/tensorflow/go/lib.go @@ -18,14 +18,4 @@ package tensorflow // #cgo LDFLAGS: -ltensorflow // #cgo CFLAGS: -I${SRCDIR}/../../ -// -// // TODO(ashankar): Remove this after TensorFlow 1.1 has been released. -// // Till then, the TensorFlow C API binary releases do not contain -// // the TF_DeletePRunHandle symbol. We work around that by -// // implementing the equivalent in session.cpp -// extern void tfDeletePRunHandle(const char*); import "C" - -func deletePRunHandle(h *C.char) { - C.tfDeletePRunHandle(h) -} diff --git a/tensorflow/go/session.cpp b/tensorflow/go/session.cpp deleted file mode 100644 index efa225505b..0000000000 --- a/tensorflow/go/session.cpp +++ /dev/null @@ -1,26 +0,0 @@ -/* -Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -// TODO(ashankar): Remove this file when TensorFlow 1.1 is released. -// See lib.go for details. - -extern "C" { -extern void tfDeletePRunHandle(const char* h); -} - -void tfDeletePRunHandle(const char* h) { - delete[] h; -} diff --git a/tensorflow/go/session.go b/tensorflow/go/session.go index 3add412dcd..afa73030b8 100644 --- a/tensorflow/go/session.go +++ b/tensorflow/go/session.go @@ -199,7 +199,7 @@ func (s *Session) NewPartialRun(feeds, fetches []Output, targets []*Operation) ( return nil, err } runtime.SetFinalizer(pr, func(pr *PartialRun) { - deletePRunHandle(pr.handle) + C.TF_DeletePRunHandle(pr.handle) }) return pr, nil } -- GitLab From 28de1790f6a922471c44740fde08995d56a4ee5b Mon Sep 17 00:00:00 2001 From: Gunhan Gulsoy Date: Tue, 16 May 2017 13:24:52 -0700 Subject: [PATCH 687/697] Disable sdca_estimator_test in tsan. PiperOrigin-RevId: 156222038 --- tensorflow/contrib/linear_optimizer/BUILD | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow/contrib/linear_optimizer/BUILD b/tensorflow/contrib/linear_optimizer/BUILD index 472633b5c7..1fde6e5c6c 100644 --- a/tensorflow/contrib/linear_optimizer/BUILD +++ b/tensorflow/contrib/linear_optimizer/BUILD @@ -127,6 +127,7 @@ py_test( name = "sdca_estimator_test", srcs = ["python/sdca_estimator_test.py"], srcs_version = "PY2AND3", + tags = ["notsan"], deps = [ ":sdca_estimator_py", "//tensorflow/contrib/layers:layers_py", -- GitLab From eae242a4a9824655667c73637cf4498ebc0fe1c5 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 16 May 2017 13:47:10 -0700 Subject: [PATCH 688/697] Internal changes. PiperOrigin-RevId: 156224948 --- tensorflow/python/kernel_tests/BUILD | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD index 6689d6a6b4..5d9534a206 100644 --- a/tensorflow/python/kernel_tests/BUILD +++ b/tensorflow/python/kernel_tests/BUILD @@ -1564,6 +1564,7 @@ cuda_py_test( "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:math_ops", ], + shard_count = 4, ) cuda_py_test( -- GitLab From d0936ab91ccdc0e50dc31b0c0bb7711132fcd445 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 16 May 2017 15:03:38 -0700 Subject: [PATCH 689/697] Fix the issue in which the user choppily moves into a different graph transform configuration after panning/zooming after fitting the graph. The main problem had been that the fit function had been altering the root 'g' element's transform, while the zoom behavior had been instead altering the root 'svg' element's transform. PiperOrigin-RevId: 156236623 --- .../components/tf_graph_common_d3v4/scene.ts | 2 +- .../components/tf_graph_d3v4/tf-graph-scene.html | 16 ++++++++-------- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/tensorflow/tensorboard/components/tf_graph_common_d3v4/scene.ts b/tensorflow/tensorboard/components/tf_graph_common_d3v4/scene.ts index 023bc161f5..29f9b446b3 100644 --- a/tensorflow/tensorboard/components/tf_graph_common_d3v4/scene.ts +++ b/tensorflow/tensorboard/components/tf_graph_common_d3v4/scene.ts @@ -155,7 +155,7 @@ module tf.graph.scene { .scale(scale) .translate(params.padding.paddingLeft, params.padding.paddingTop); - d3.select(zoomG) + d3.select(svg) .transition() .duration(500) .call(d3zoom.transform, transform) diff --git a/tensorflow/tensorboard/components/tf_graph_d3v4/tf-graph-scene.html b/tensorflow/tensorboard/components/tf_graph_d3v4/tf-graph-scene.html index 95d9d16f85..10a65f54d5 100644 --- a/tensorflow/tensorboard/components/tf_graph_d3v4/tf-graph-scene.html +++ b/tensorflow/tensorboard/components/tf_graph_d3v4/tf-graph-scene.html @@ -649,12 +649,12 @@ Polymer({ }, /** Keeps track of the starting coordinates of a graph zoom/pan */ _zoomStartCoords: { - type: Array, + type: Object, value: null }, /** Keeps track of the current coordinates of a graph zoom/pan */ - _zoomCoords: { - type: Array, + _zoomTransform: { + type: Object, value: null }, /** Maximum distance of a zoom event for it to be interpreted as a click */ @@ -797,8 +797,8 @@ Polymer({ // is ignored (as this mouse click was part of a zooming, and should // not be used to indicate an actual click on the graph). var dragDistance = Math.sqrt( - Math.pow(this._zoomStartCoords[0] - this._zoomCoords[0], 2) + - Math.pow(this._zoomStartCoords[1] - this._zoomCoords[1], 2)); + Math.pow(this._zoomStartCoords.x - this._zoomTransform.x, 2) + + Math.pow(this._zoomStartCoords.y - this._zoomTransform.y, 2)); if (dragDistance < this._maxZoomDistanceForClick) { this._fireEnableClick(); } else { @@ -808,8 +808,8 @@ Polymer({ this._zoomStartCoords = null; }.bind(this)) .on('zoom', function() { - // Store the coordinates of the zoom event - this._zoomCoords = [d3.event.transform.x, d3.event.transform.y]; + // Store the coordinates of the zoom event. + this._zoomTransform = d3.event.transform; // If this is the first zoom event after a zoom-end, then // store the coordinates as the start coordinates as well, @@ -818,7 +818,7 @@ Polymer({ // event on mouse-down, even if there has been no dragging // done to translate the graph around. if (!this._zoomStartCoords) { - this._zoomStartCoords = this._zoomCoords; + this._zoomStartCoords = this._zoomTransform; this.fire('disable-click'); } this._zoomed = true; -- GitLab From b45ded11d4d9a8f937004782115d2bbb29b97efd Mon Sep 17 00:00:00 2001 From: Alexey Surkov Date: Tue, 16 May 2017 15:19:46 -0700 Subject: [PATCH 690/697] Make the readahead buffer size configurable via an env variable. PiperOrigin-RevId: 156238858 --- tensorflow/core/platform/cloud/gcs_file_system.cc | 13 ++++++++++++- tensorflow/core/platform/cloud/gcs_file_system.h | 3 ++- .../core/platform/cloud/gcs_file_system_test.cc | 9 +++++++++ 3 files changed, 23 insertions(+), 2 deletions(-) diff --git a/tensorflow/core/platform/cloud/gcs_file_system.cc b/tensorflow/core/platform/cloud/gcs_file_system.cc index 5ee3099673..97e4c207d8 100644 --- a/tensorflow/core/platform/cloud/gcs_file_system.cc +++ b/tensorflow/core/platform/cloud/gcs_file_system.cc @@ -52,6 +52,8 @@ constexpr int kGetChildrenDefaultPageSize = 1000; constexpr uint64 kUploadRetryDelayMicros = 1000000L; // The HTTP response code "308 Resume Incomplete". constexpr uint64 HTTP_CODE_RESUME_INCOMPLETE = 308; +// The environment variable that overrides the size of the readahead buffer. +constexpr char kReadaheadBufferSize[] = "GCS_READAHEAD_BUFFER_SIZE_BYTES"; // The file statistics returned by Stat() for directories. const FileStatistics DIRECTORY_STAT(0, 0, true); @@ -585,7 +587,16 @@ class GcsReadOnlyMemoryRegion : public ReadOnlyMemoryRegion { GcsFileSystem::GcsFileSystem() : auth_provider_(new GoogleAuthProvider()), - http_request_factory_(new HttpRequest::Factory()) {} + http_request_factory_(new HttpRequest::Factory()) { + // Apply the sys env override for the readahead buffer size if it's provided. + const char* readahead_buffer_size = std::getenv(kReadaheadBufferSize); + if (readahead_buffer_size) { + uint64 value; + if (strings::safe_strtou64(readahead_buffer_size, &value)) { + read_ahead_bytes_ = value; + } + } +} GcsFileSystem::GcsFileSystem( std::unique_ptr auth_provider, diff --git a/tensorflow/core/platform/cloud/gcs_file_system.h b/tensorflow/core/platform/cloud/gcs_file_system.h index 6a6437f070..18d2de482b 100644 --- a/tensorflow/core/platform/cloud/gcs_file_system.h +++ b/tensorflow/core/platform/cloud/gcs_file_system.h @@ -74,6 +74,7 @@ class GcsFileSystem : public FileSystem { Status DeleteRecursively(const string& dirname, int64* undeleted_files, int64* undeleted_dirs) override; + size_t get_readahead_buffer_size() const { return read_ahead_bytes_; } private: /// \brief Checks if the bucket exists. Returns OK if the check succeeded. @@ -112,7 +113,7 @@ class GcsFileSystem : public FileSystem { // The number of bytes to read ahead for buffering purposes in the // RandomAccessFile implementation. Defaults to 256Mb. - const size_t read_ahead_bytes_ = 256 * 1024 * 1024; + size_t read_ahead_bytes_ = 256 * 1024 * 1024; // The initial delay for exponential backoffs when retrying failed calls. const int64 initial_retry_delay_usec_ = 1000000L; diff --git a/tensorflow/core/platform/cloud/gcs_file_system_test.cc b/tensorflow/core/platform/cloud/gcs_file_system_test.cc index fc79f3be11..c3a8678fbc 100644 --- a/tensorflow/core/platform/cloud/gcs_file_system_test.cc +++ b/tensorflow/core/platform/cloud/gcs_file_system_test.cc @@ -1617,5 +1617,14 @@ TEST(GcsFileSystemTest, DeleteRecursively_NotAFolder) { EXPECT_EQ(1, undeleted_dirs); } +TEST(GcsFileSystemTest, OverrideReadaheadBufferSize) { + GcsFileSystem fs1; + EXPECT_EQ(256 * 1024 * 1024, fs1.get_readahead_buffer_size()); + + setenv("GCS_READAHEAD_BUFFER_SIZE_BYTES", "123456789", 1); + GcsFileSystem fs2; + EXPECT_EQ(123456789L, fs2.get_readahead_buffer_size()); +} + } // namespace } // namespace tensorflow -- GitLab From 7ca252469e86101eb145cc1412b9ca3bc3181684 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 16 May 2017 15:20:58 -0700 Subject: [PATCH 691/697] Handle R0 DynamicUpdateSlice in algebraic simplifier. PiperOrigin-RevId: 156239012 --- .../compiler/xla/service/algebraic_simplifier.cc | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index f240b1ebe9..3f888b4c2e 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -1,3 +1,4 @@ + /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); @@ -167,6 +168,10 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { Status HandleReverse(HloInstruction* reverse, HloInstruction* operand) override; Status HandleSlice(HloInstruction* slice, HloInstruction* operand) override; + Status HandleDynamicUpdateSlice(HloInstruction* dynamic_update_slice, + HloInstruction* operand, + HloInstruction* update, + HloInstruction* start_indices) override; Status HandleTranspose(HloInstruction* transpose) override; @@ -1020,6 +1025,16 @@ Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice, return Status::OK(); } +Status AlgebraicSimplifierVisitor::HandleDynamicUpdateSlice( + HloInstruction* dynamic_update_slice, HloInstruction* operand, + HloInstruction* update, HloInstruction* start_indices) { + // DynamicUpdateSlice on a scalar just passes through the update argument. + if (ShapeUtil::IsScalar(dynamic_update_slice->shape())) { + return ReplaceInstruction(dynamic_update_slice, update); + } + return Status::OK(); +} + Status AlgebraicSimplifierVisitor::HandleReduce( HloInstruction* reduce, HloInstruction* arg, HloInstruction* init_value, tensorflow::gtl::ArraySlice dimensions, HloComputation* function) { -- GitLab From 16a8a6d7d834e4645e007b2e3cfbe26511bf133a Mon Sep 17 00:00:00 2001 From: Kay Zhu Date: Tue, 16 May 2017 15:25:01 -0700 Subject: [PATCH 692/697] [TF] A couple fixes for Mod: In TF: - fix incorrect doc string for TruncateMod. - remove outdated TODO. PiperOrigin-RevId: 156239535 --- tensorflow/core/ops/math_ops.cc | 13 +++++++------ tensorflow/python/ops/math_ops.py | 2 -- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc index 79a8be0048..28c4ec643e 100644 --- a/tensorflow/core/ops/math_ops.cc +++ b/tensorflow/core/ops/math_ops.cc @@ -595,7 +595,9 @@ REGISTER_OP("Mod") .Attr("T: {int32, int64, float, double}") .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn) .Doc(R"doc( -Returns element-wise remainder of division. +Returns element-wise remainder of division. This emulates C semantics in that +the result here is consistent with a truncating divide. E.g. `truncate(x / y) * +y + truncate_mod(x, y) = x`. *NOTE*: `Mod` supports broadcasting. More about broadcasting [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) @@ -623,12 +625,11 @@ REGISTER_OP("TruncateMod") .Attr("T: {int32, int64, float, double}") .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn) .Doc(R"doc( -Returns element-wise remainder of division. This emulates C semantics where +Returns element-wise remainder of division. This emulates C semantics in that +the result here is consistent with a truncating divide. E.g. `truncate(x / y) * +y + truncate_mod(x, y) = x`. -true, this follows C semantics in that the result here is consistent -with a flooring divide. E.g. `floor(x / y) * y + mod(x, y) = x`. - -*NOTE*: `Mod` supports broadcasting. More about broadcasting +*NOTE*: `TruncateMod` supports broadcasting. More about broadcasting [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) )doc"); diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py index 158016ff37..1555d19395 100644 --- a/tensorflow/python/ops/math_ops.py +++ b/tensorflow/python/ops/math_ops.py @@ -1082,8 +1082,6 @@ _OverrideBinaryOperatorHelper(_mul_dispatch, "mul") _OverrideBinaryOperatorHelper(_div_python2, "div") _OverrideBinaryOperatorHelper(_truediv_python3, "truediv") _OverrideBinaryOperatorHelper(floordiv, "floordiv") -# TODO(aselle): Switch mod to floor_mod when ready -# _OverrideBinaryOperatorHelper(gen_math_ops.floor_mod, "mod") _OverrideBinaryOperatorHelper(gen_math_ops._floor_mod, "mod") _OverrideBinaryOperatorHelper(pow, "pow") -- GitLab From 77583357ed95122ed85915556f639f0cbefa8251 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 16 May 2017 15:37:49 -0700 Subject: [PATCH 693/697] Automated g4 rollback of changelist 156170472 PiperOrigin-RevId: 156240877 --- tensorflow/BUILD | 2 +- tensorflow/tools/docs/BUILD | 15 --------------- 2 files changed, 1 insertion(+), 16 deletions(-) diff --git a/tensorflow/BUILD b/tensorflow/BUILD index 0c2fdab9b8..503ad79a38 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -458,7 +458,7 @@ filegroup( filegroup( name = "docs_src", - srcs = glob(["docs_src/**/*.md"]), + data = glob(["docs_src/**/*.md"]), ) # ------------------------------------------- diff --git a/tensorflow/tools/docs/BUILD b/tensorflow/tools/docs/BUILD index 0c0239352b..8e27b133c2 100644 --- a/tensorflow/tools/docs/BUILD +++ b/tensorflow/tools/docs/BUILD @@ -151,18 +151,3 @@ filegroup( ], ), ) - -genrule( - name = "python_docs", - srcs = ["//tensorflow:docs_src"], - outs = ["python_docs.tgz"], - cmd = "STARTDIR=$$(pwd); " + - "TMP=$$(mktemp -d $${TMPDIR:-/tmp}/docs.XXXXXXXXXX); " + - "cd $$TMP; " + - "$$STARTDIR/$(location :generate) --src_dir=$$STARTDIR/third_party/tensorflow/docs_src --output_dir=docs_out; " + - "tar -czf $$STARTDIR/$@ docs_out; " + - "cd $$STARTDIR; " + - "rm -rf $$TMP; ", - tools = [":generate"], - visibility = ["//visibility:public"], -) -- GitLab From ed5d05d8b53425ef98aad129a60143a5011a4288 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 16 May 2017 16:03:52 -0700 Subject: [PATCH 694/697] Update ops-related pbtxt files. PiperOrigin-RevId: 156244273 --- tensorflow/core/ops/ops.pbtxt | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt index 0ff321aed5..cd43881a46 100644 --- a/tensorflow/core/ops/ops.pbtxt +++ b/tensorflow/core/ops/ops.pbtxt @@ -12001,8 +12001,8 @@ op { } } } - summary: "Returns element-wise remainder of division." - description: "*NOTE*: `Mod` supports broadcasting. More about broadcasting\n[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)" + summary: "Returns element-wise remainder of division. This emulates C semantics in that" + description: "the result here is consistent with a truncating divide. E.g. `truncate(x / y) *\ny + truncate_mod(x, y) = x`.\n\n*NOTE*: `Mod` supports broadcasting. More about broadcasting\n[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)" } op { name: "Mul" @@ -26637,8 +26637,8 @@ op { } } } - summary: "Returns element-wise remainder of division. This emulates C semantics where" - description: "true, this follows C semantics in that the result here is consistent\nwith a flooring divide. E.g. `floor(x / y) * y + mod(x, y) = x`.\n\n*NOTE*: `Mod` supports broadcasting. More about broadcasting\n[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)" + summary: "Returns element-wise remainder of division. This emulates C semantics in that" + description: "the result here is consistent with a truncating divide. E.g. `truncate(x / y) *\ny + truncate_mod(x, y) = x`.\n\n*NOTE*: `TruncateMod` supports broadcasting. More about broadcasting\n[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)" } op { name: "TruncatedNormal" -- GitLab From 749e5cc18381f7a5ec174673f76e20aead8529c6 Mon Sep 17 00:00:00 2001 From: Geoffrey Irving Date: Tue, 16 May 2017 16:08:20 -0700 Subject: [PATCH 695/697] Reduce direct references to NodeDef in favor of Node and AttrSlice This is one step towards replacing in-memory use of NodeDef with a customized NodeInfo class. There are still quite a few Node::def() references, but far fewer than before. Those remaining require more work, either because they are part of kernel registration (which is a bunch of functions), copy and modify the NodeDef, etc. Follow-on CLs will remove more. RELNOTES: n/a PiperOrigin-RevId: 156244933 --- tensorflow/c/c_api.cc | 63 +++++----- tensorflow/cc/framework/cc_op_gen.cc | 9 +- tensorflow/cc/framework/cc_ops_test.cc | 17 ++- tensorflow/cc/framework/scope.cc | 4 +- tensorflow/cc/gradients/array_grad.cc | 28 +++-- tensorflow/cc/gradients/math_grad.cc | 8 +- tensorflow/cc/ops/const_op_test.cc | 8 +- tensorflow/compiler/aot/compile.cc | 12 +- .../compiler/jit/build_xla_launch_ops_pass.cc | 9 +- .../jit/encapsulate_subgraphs_pass.cc | 10 +- .../compiler/jit/graph_to_functiondef.cc | 14 +-- .../jit/kernels/xla_device_launch_op.cc | 4 +- .../jit/kernels/xla_local_launch_op.cc | 2 +- .../compiler/jit/mark_for_compilation_pass.cc | 42 ++++--- .../jit/mark_for_compilation_pass_test.cc | 2 +- .../compiler/jit/xla_compilation_cache.cc | 2 +- tensorflow/compiler/tf2xla/const_analysis.cc | 6 +- .../compiler/tf2xla/kernels/function_ops.cc | 3 +- tensorflow/compiler/tf2xla/xla_compiler.cc | 7 +- tensorflow/core/BUILD | 7 +- .../common_runtime/constant_folding_test.cc | 8 +- tensorflow/core/common_runtime/executor.cc | 56 +++++---- tensorflow/core/common_runtime/function.cc | 37 +++--- .../core/common_runtime/function_test.cc | 40 +++++-- .../kernel_benchmark_testlib.cc | 8 +- .../parallel_concat_optimizer.cc | 4 +- .../resource_variable_read_optimizer.cc | 2 +- .../core/common_runtime/simple_placer.cc | 34 +++--- tensorflow/core/debug/debug_graph_utils.cc | 23 ++-- .../distributed_runtime/master_session.cc | 12 +- tensorflow/core/framework/function.cc | 102 +++++++--------- tensorflow/core/framework/function.h | 49 ++++---- tensorflow/core/framework/function_test.cc | 110 +++++++++++------- tensorflow/core/framework/node_def_util.cc | 92 +++++++++++---- tensorflow/core/framework/node_def_util.h | 39 ++++++- tensorflow/core/framework/op_kernel.cc | 23 ++-- tensorflow/core/framework/shape_inference.h | 9 +- tensorflow/core/graph/control_flow.cc | 2 +- tensorflow/core/graph/graph.cc | 4 +- tensorflow/core/graph/graph.h | 14 +++ tensorflow/core/graph/graph_constructor.cc | 4 +- .../core/graph/graph_constructor_test.cc | 66 +++++------ tensorflow/core/graph/graph_partition.cc | 8 +- tensorflow/core/graph/graph_test.cc | 10 +- tensorflow/core/graph/optimizer_cse.cc | 37 ++---- tensorflow/core/graph/quantize_training.cc | 6 +- .../core/graph/quantize_training_test.cc | 12 +- tensorflow/core/graph/subgraph.cc | 6 +- tensorflow/core/graph/subgraph_test.cc | 4 +- .../core/grappler/costs/graph_properties.cc | 4 +- tensorflow/core/kernels/function_ops.cc | 16 +-- .../core/kernels/hexagon/graph_transferer.cc | 19 +-- .../remote_fused_graph_execute_utils.cc | 6 +- .../remote_fused_graph_execute_utils.h | 2 +- 54 files changed, 605 insertions(+), 520 deletions(-) diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index bbc18569ad..f4775783f9 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -738,8 +738,7 @@ tensorflow::string OutputName(const TF_Output& output) { const tensorflow::AttrValue* GetAttrValue(TF_Operation* oper, const char* attr_name, TF_Status* status) { - const tensorflow::AttrValue* attr = - tensorflow::AttrSlice(oper->node.def()).Find(attr_name); + const tensorflow::AttrValue* attr = oper->node.attrs().Find(attr_name); if (attr == nullptr) { status->status = InvalidArgument("Operation has no attr named '", attr_name, "'."); @@ -1135,7 +1134,7 @@ const char* TF_OperationOpType(TF_Operation* oper) { } const char* TF_OperationDevice(TF_Operation* oper) { - return oper->node.def().device().c_str(); + return oper->node.requested_device().c_str(); } int TF_OperationNumOutputs(TF_Operation* oper) { @@ -1150,8 +1149,8 @@ TF_DataType TF_OperationOutputType(TF_Output oper_out) { int TF_OperationOutputListLength(TF_Operation* oper, const char* arg_name, TF_Status* status) { NameRangeMap name_ranges; - status->status = NameRangesForNode(oper->node.def(), oper->node.op_def(), - nullptr, &name_ranges); + status->status = + NameRangesForNode(oper->node, oper->node.op_def(), nullptr, &name_ranges); if (!status->status.ok()) return -1; auto iter = name_ranges.find(arg_name); if (iter == name_ranges.end()) { @@ -1172,8 +1171,8 @@ TF_DataType TF_OperationInputType(TF_Input oper_in) { int TF_OperationInputListLength(TF_Operation* oper, const char* arg_name, TF_Status* status) { NameRangeMap name_ranges; - status->status = NameRangesForNode(oper->node.def(), oper->node.op_def(), - &name_ranges, nullptr); + status->status = + NameRangesForNode(oper->node, oper->node.op_def(), &name_ranges, nullptr); if (!status->status.ok()) return -1; auto iter = name_ranges.find(arg_name); if (iter == name_ranges.end()) { @@ -1411,26 +1410,27 @@ void TF_OperationGetAttrStringList(TF_Operation* oper, const char* attr_name, } } -#define DEFINE_GETATTR(func, c_type, cpp_type, list_field) \ - void func(TF_Operation* oper, const char* attr_name, c_type* value, \ - TF_Status* status) { \ - cpp_type v; \ - status->status = tensorflow::GetNodeAttr(oper->node.def(), attr_name, &v); \ - *value = static_cast(v); \ - } \ - void func##List(TF_Operation* oper, const char* attr_name, c_type* values, \ - int max_values, TF_Status* status) { \ - const auto* attr = GetAttrValue(oper, attr_name, status); \ - if (!status->status.ok()) return; \ - if (attr->value_case() != tensorflow::AttrValue::kList) { \ - status->status = \ - InvalidArgument("Value for '", attr_name, "' is not a list."); \ - return; \ - } \ - const auto len = std::min(max_values, attr->list().list_field##_size()); \ - for (int i = 0; i < len; ++i) { \ - values[i] = static_cast(attr->list().list_field(i)); \ - } \ +#define DEFINE_GETATTR(func, c_type, cpp_type, list_field) \ + void func(TF_Operation* oper, const char* attr_name, c_type* value, \ + TF_Status* status) { \ + cpp_type v; \ + status->status = \ + tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &v); \ + *value = static_cast(v); \ + } \ + void func##List(TF_Operation* oper, const char* attr_name, c_type* values, \ + int max_values, TF_Status* status) { \ + const auto* attr = GetAttrValue(oper, attr_name, status); \ + if (!status->status.ok()) return; \ + if (attr->value_case() != tensorflow::AttrValue::kList) { \ + status->status = \ + InvalidArgument("Value for '", attr_name, "' is not a list."); \ + return; \ + } \ + const auto len = std::min(max_values, attr->list().list_field##_size()); \ + for (int i = 0; i < len; ++i) { \ + values[i] = static_cast(attr->list().list_field(i)); \ + } \ } DEFINE_GETATTR(TF_OperationGetAttrInt, int64_t, tensorflow::int64, i); DEFINE_GETATTR(TF_OperationGetAttrFloat, float, float, f); @@ -1441,7 +1441,8 @@ DEFINE_GETATTR(TF_OperationGetAttrType, TF_DataType, DataType, type); void TF_OperationGetAttrShape(TF_Operation* oper, const char* attr_name, int64_t* value, int num_dims, TF_Status* status) { PartialTensorShape shape; - status->status = tensorflow::GetNodeAttr(oper->node.def(), attr_name, &shape); + status->status = + tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &shape); if (!status->status.ok()) return; auto len = std::min(shape.dims(), num_dims); for (int i = 0; i < len; ++i) { @@ -1455,7 +1456,7 @@ void TF_OperationGetAttrShapeList(TF_Operation* oper, const char* attr_name, int storage_size, TF_Status* status) { std::vector shapes; status->status = - tensorflow::GetNodeAttr(oper->node.def(), attr_name, &shapes); + tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &shapes); if (!status->status.ok()) return; auto len = std::min(static_cast(shapes.size()), max_values); int64_t* p = storage; @@ -1522,7 +1523,7 @@ void TF_OperationGetAttrTensor(TF_Operation* oper, const char* attr_name, TF_Tensor** value, TF_Status* status) { *value = nullptr; Tensor t; - status->status = tensorflow::GetNodeAttr(oper->node.def(), attr_name, &t); + status->status = tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &t); if (!status->status.ok()) return; *value = new TF_Tensor{static_cast(t.dtype()), t.shape(), tensorflow::TensorCApi::Buffer(t)}; @@ -1533,7 +1534,7 @@ void TF_OperationGetAttrTensorList(TF_Operation* oper, const char* attr_name, TF_Tensor** values, int max_values, TF_Status* status) { std::vector ts; - status->status = tensorflow::GetNodeAttr(oper->node.def(), attr_name, &ts); + status->status = tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &ts); if (!status->status.ok()) return; const auto len = std::min(max_values, static_cast(ts.size())); for (int i = 0; i < len; ++i) { diff --git a/tensorflow/cc/framework/cc_op_gen.cc b/tensorflow/cc/framework/cc_op_gen.cc index 799492a4eb..71aa986f91 100644 --- a/tensorflow/cc/framework/cc_op_gen.cc +++ b/tensorflow/cc/framework/cc_op_gen.cc @@ -740,11 +740,10 @@ void OpInfo::GetOutput(string* out) const { return; } strings::StrAppend(out, " ::tensorflow::NameRangeMap _outputs_range;\n"); - strings::StrAppend( - out, - " ::tensorflow::Status _status_ = " - "::tensorflow::NameRangesForNode(ret->def(), ret->op_def(), " - "nullptr, &_outputs_range);\n"); + strings::StrAppend(out, + " ::tensorflow::Status _status_ = " + "::tensorflow::NameRangesForNode(*ret, ret->op_def(), " + "nullptr, &_outputs_range);\n"); strings::StrAppend(out, " if (!_status_.ok()) {\n", " ", scope_str, ".UpdateStatus(_status_);\n", " return;\n"); strings::StrAppend(out, " }\n\n"); diff --git a/tensorflow/cc/framework/cc_ops_test.cc b/tensorflow/cc/framework/cc_ops_test.cc index 92c97d107d..5da23036ea 100644 --- a/tensorflow/cc/framework/cc_ops_test.cc +++ b/tensorflow/cc/framework/cc_ops_test.cc @@ -35,8 +35,8 @@ Output Linear(const Scope& scope, Input x, Input w, Input b) { void GetColocationConstraints(const Output& tensor, std::vector* constraints) { constraints->clear(); - TF_EXPECT_OK( - GetNodeAttr(tensor.op().node()->def(), kColocationAttrName, constraints)); + TF_EXPECT_OK(GetNodeAttr(tensor.op().node()->attrs(), kColocationAttrName, + constraints)); } } // namespace @@ -159,11 +159,11 @@ TEST(CCOpTest, KernelLabel) { Scope root = Scope::NewRootScope(); auto add = Add(root.WithKernelLabel("AddWithKernelLabel"), 1.0f, 2.0f); TF_EXPECT_OK(root.status()); - const auto& attrs = add.z.op().node()->def().attr(); - ASSERT_TRUE(attrs.find("_kernel") != attrs.end()); - auto kernel_attr = attrs.find("_kernel")->second; - TF_EXPECT_OK(AttrValueHasType(kernel_attr, "string")); - EXPECT_EQ(kernel_attr.s(), "AddWithKernelLabel"); + AttrSlice attrs = add.z.op().node()->attrs(); + const auto* kernel_attr = attrs.Find("_kernel"); + ASSERT_TRUE(kernel_attr); + TF_EXPECT_OK(AttrValueHasType(*kernel_attr, "string")); + EXPECT_EQ(kernel_attr->s(), "AddWithKernelLabel"); } TEST(CCOpTest, ColocateWith) { @@ -190,8 +190,7 @@ TEST(CCOpTest, ColocateWith) { Scope with_colocate = root.ColocateWith(c3).ColocateWith(c4); auto c6 = Const(with_colocate.WithOpName("c6").ClearColocation(), 7); - const auto& attrs = c6.op().node()->def().attr(); - EXPECT_TRUE(attrs.find("_class") == attrs.end()); + EXPECT_FALSE(c6.op().node()->attrs().Find("_class")); } TEST(CCOpTest, TemplatedConst) { diff --git a/tensorflow/cc/framework/scope.cc b/tensorflow/cc/framework/scope.cc index 8b7fc1406f..32c0822de6 100644 --- a/tensorflow/cc/framework/scope.cc +++ b/tensorflow/cc/framework/scope.cc @@ -271,9 +271,9 @@ Scope::Impl::Impl(const Scope& other, Tags::Colocate, std::unordered_set Scope::Impl::GetColocationConstraints( const Operation& colocate_with_op) const { std::unordered_set current_constraints(colocation_constraints_); - const NodeDef& node_def = colocate_with_op.node()->def(); + const AttrSlice attrs = colocate_with_op.node()->attrs(); std::vector node_constraints; - if (GetNodeAttr(node_def, kColocationAttrName, &node_constraints).ok()) { + if (GetNodeAttr(attrs, kColocationAttrName, &node_constraints).ok()) { for (const string& entry : node_constraints) { StringPiece s(entry); if (s.Consume(kColocationGroupPrefix)) { diff --git a/tensorflow/cc/gradients/array_grad.cc b/tensorflow/cc/gradients/array_grad.cc index 26abd2438e..37f07e71a0 100644 --- a/tensorflow/cc/gradients/array_grad.cc +++ b/tensorflow/cc/gradients/array_grad.cc @@ -43,9 +43,9 @@ Status PackGrad(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { int N; - TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "N", &N)); + TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "N", &N)); int axis; - TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "axis", &axis)); + TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "axis", &axis)); grad_outputs->reserve(N); auto grad_op = Unstack(scope, grad_inputs[0], N, Unstack::Axis(axis)); @@ -60,7 +60,7 @@ Status UnpackGrad(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { int axis; - TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "axis", &axis)); + TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "axis", &axis)); grad_outputs->push_back(Stack(scope, grad_inputs, Stack::Axis(axis))); return scope.status(); } @@ -162,7 +162,7 @@ Status CheckNumericsGrad(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { string message; - TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "message", &message)); + TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "message", &message)); string err_msg = strings::StrCat( "Not a number (NaN) or infinity (Inf) values detected in gradient. ", message); @@ -215,9 +215,9 @@ Status ReverseSequenceGrad(const Scope& scope, const Operation& op, std::vector* grad_outputs) { auto seq_lengths = op.input(1); int batch_dim; - TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "batch_dim", &batch_dim)); + TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "batch_dim", &batch_dim)); int seq_dim; - TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "seq_dim", &seq_dim)); + TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "seq_dim", &seq_dim)); grad_outputs->push_back( ReverseSequence(scope, grad_inputs[0], seq_lengths, seq_dim, ReverseSequence::BatchDim(batch_dim))); @@ -267,7 +267,8 @@ Status SpaceToBatchGrad(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { int block_size; - TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "block_size", &block_size)); + TF_RETURN_IF_ERROR( + GetNodeAttr(op.node()->attrs(), "block_size", &block_size)); grad_outputs->push_back( BatchToSpace(scope, grad_inputs[0], op.input(1), block_size)); grad_outputs->push_back(NoGradient()); @@ -290,7 +291,8 @@ Status BatchToSpaceGrad(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { int block_size; - TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "block_size", &block_size)); + TF_RETURN_IF_ERROR( + GetNodeAttr(op.node()->attrs(), "block_size", &block_size)); grad_outputs->push_back( SpaceToBatch(scope, grad_inputs[0], op.input(1), block_size)); grad_outputs->push_back(NoGradient()); @@ -313,7 +315,8 @@ Status SpaceToDepthGrad(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { int block_size; - TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "block_size", &block_size)); + TF_RETURN_IF_ERROR( + GetNodeAttr(op.node()->attrs(), "block_size", &block_size)); grad_outputs->push_back(DepthToSpace(scope, grad_inputs[0], block_size)); return scope.status(); } @@ -323,7 +326,8 @@ Status DepthToSpaceGrad(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { int block_size; - TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "block_size", &block_size)); + TF_RETURN_IF_ERROR( + GetNodeAttr(op.node()->attrs(), "block_size", &block_size)); grad_outputs->push_back(SpaceToDepth(scope, grad_inputs[0], block_size)); return scope.status(); } @@ -333,7 +337,7 @@ Status MirrorPadGrad(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { string mode; - TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "mode", &mode)); + TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "mode", &mode)); grad_outputs->push_back(tensorflow::ops::internal::MirrorPadGrad( scope, grad_inputs[0], op.input(1), mode)); grad_outputs->push_back(NoGradient()); @@ -346,7 +350,7 @@ Status MirrorPadGradGrad(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { string mode; - TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "mode", &mode)); + TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "mode", &mode)); grad_outputs->push_back(MirrorPad(scope, grad_inputs[0], op.input(1), mode)); grad_outputs->push_back(NoGradient()); return scope.status(); diff --git a/tensorflow/cc/gradients/math_grad.cc b/tensorflow/cc/gradients/math_grad.cc index 5a2c6d11fb..8c1a01f518 100644 --- a/tensorflow/cc/gradients/math_grad.cc +++ b/tensorflow/cc/gradients/math_grad.cc @@ -350,7 +350,7 @@ Status MatMulGradCommon(const Scope& scope, const Operation& op, const string& attr_adj_x, const string& attr_adj_y, std::vector* grad_outputs) { DataType dtype; - TF_RETURN_IF_ERROR(GetNodeAttr(op.output(0).node()->def(), "T", &dtype)); + TF_RETURN_IF_ERROR(GetNodeAttr(op.output(0).node()->attrs(), "T", &dtype)); if (dtype == DT_COMPLEX64 || dtype == DT_COMPLEX128) { return errors::Unimplemented( "MatMul gradient for complex data type is not supported yet."); @@ -358,8 +358,10 @@ Status MatMulGradCommon(const Scope& scope, const Operation& op, bool ta; bool tb; - TF_RETURN_IF_ERROR(GetNodeAttr(op.output(0).node()->def(), attr_adj_x, &ta)); - TF_RETURN_IF_ERROR(GetNodeAttr(op.output(0).node()->def(), attr_adj_y, &tb)); + TF_RETURN_IF_ERROR( + GetNodeAttr(op.output(0).node()->attrs(), attr_adj_x, &ta)); + TF_RETURN_IF_ERROR( + GetNodeAttr(op.output(0).node()->attrs(), attr_adj_y, &tb)); if (!ta && !tb) { return MatMulGradHelper(scope, is_batch, grad_inputs[0], false, op.input(1), diff --git a/tensorflow/cc/ops/const_op_test.cc b/tensorflow/cc/ops/const_op_test.cc index 5a4770f879..3184edeb33 100644 --- a/tensorflow/cc/ops/const_op_test.cc +++ b/tensorflow/cc/ops/const_op_test.cc @@ -28,9 +28,9 @@ void ExpectNodeEqual(const Node* n, gtl::ArraySlice values, TensorShape shape) { EXPECT_TRUE(n->IsConstant()); Tensor tensor; - TF_EXPECT_OK(GetNodeAttr(n->def(), "value", &tensor)); + TF_EXPECT_OK(GetNodeAttr(n->attrs(), "value", &tensor)); DataType dtype; - TF_EXPECT_OK(GetNodeAttr(n->def(), "dtype", &dtype)); + TF_EXPECT_OK(GetNodeAttr(n->attrs(), "dtype", &dtype)); EXPECT_EQ(tensor.dtype(), dtype); test::ExpectTensorEqual(tensor, test::AsTensor(values, shape)); } @@ -39,9 +39,9 @@ void ExpectTypeAndShape(const Node* n, DataType expected_dtype, TensorShape expected_shape) { EXPECT_TRUE(n->IsConstant()); Tensor tensor; - TF_EXPECT_OK(GetNodeAttr(n->def(), "value", &tensor)); + TF_EXPECT_OK(GetNodeAttr(n->attrs(), "value", &tensor)); DataType dtype; - TF_EXPECT_OK(GetNodeAttr(n->def(), "dtype", &dtype)); + TF_EXPECT_OK(GetNodeAttr(n->attrs(), "dtype", &dtype)); EXPECT_EQ(dtype, expected_dtype); EXPECT_EQ(expected_shape, TensorShape(tensor.shape())); } diff --git a/tensorflow/compiler/aot/compile.cc b/tensorflow/compiler/aot/compile.cc index 162e719ade..0c7b97b01f 100644 --- a/tensorflow/compiler/aot/compile.cc +++ b/tensorflow/compiler/aot/compile.cc @@ -203,14 +203,14 @@ Status RewriteAndPruneGraph(Graph* graph, const Config& config, for (const Node* n : graph->nodes()) { if (n->type_string() == kArgOp) { string feed_id; - TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), kFeedIdAttr, &feed_id)); + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), kFeedIdAttr, &feed_id)); if (missing_feeds.erase(feed_id) == 0) { return errors::Aborted(kArgOp, " node found with unknown feed id: ", feed_id); } } else if (n->type_string() == kRetvalOp) { string fetch_id; - TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), kFetchIdAttr, &fetch_id)); + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), kFetchIdAttr, &fetch_id)); if (missing_fetches.erase(fetch_id) == 0) { return errors::Aborted(kRetvalOp, " node found with unknown fetch id: ", fetch_id); @@ -234,7 +234,7 @@ Status CollectArgNodes(const Graph& graph, std::vector* arg_nodes) { for (Node* n : graph.nodes()) { if (n->type_string() == kArgOp) { int index; - TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "index", &index)); + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index)); auto insert_result = indexed_arg_nodes.insert({index, n}); if (!insert_result.second) { const Node* dup = insert_result.first->second; @@ -264,9 +264,9 @@ Status CreateXlaArgs(const Graph& graph, for (const Node* node : arg_nodes) { XlaCompiler::Argument arg; arg.kind = XlaCompiler::Argument::kParameter; - TF_RETURN_IF_ERROR(GetNodeAttr(node->def(), "T", &arg.type)); - TF_RETURN_IF_ERROR(GetNodeAttr(node->def(), kShapeAttr, &arg.shape)); - TF_RETURN_IF_ERROR(GetNodeAttr(node->def(), kDebugNameAttr, &arg.name)); + TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "T", &arg.type)); + TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), kShapeAttr, &arg.shape)); + TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), kDebugNameAttr, &arg.name)); xla_args->push_back(arg); } return Status::OK(); diff --git a/tensorflow/compiler/jit/build_xla_launch_ops_pass.cc b/tensorflow/compiler/jit/build_xla_launch_ops_pass.cc index abb68f73d7..48eed7fce0 100644 --- a/tensorflow/compiler/jit/build_xla_launch_ops_pass.cc +++ b/tensorflow/compiler/jit/build_xla_launch_ops_pass.cc @@ -66,9 +66,9 @@ static Status ReplaceNodeWithXlaLaunch(Graph* graph, Node* node) { int num_constant_args, num_resource_args; TF_RETURN_IF_ERROR( - GetNodeAttr(node->def(), kXlaNumConstantArgsAttr, &num_constant_args)); + GetNodeAttr(node->attrs(), kXlaNumConstantArgsAttr, &num_constant_args)); TF_RETURN_IF_ERROR( - GetNodeAttr(node->def(), kXlaNumResourceArgsAttr, &num_resource_args)); + GetNodeAttr(node->attrs(), kXlaNumResourceArgsAttr, &num_resource_args)); if (num_constant_args < 0 || num_resource_args < 0 || num_constant_args + num_resource_args > node->num_inputs()) { @@ -88,7 +88,7 @@ static Status ReplaceNodeWithXlaLaunch(Graph* graph, Node* node) { Node* launch_node; TF_RETURN_IF_ERROR(BuildLaunchNode( graph->NewName(node->name()), node->type_string(), node->def().attr(), - node->def().device(), const_dtypes, num_resource_args, arg_dtypes, + node->requested_device(), const_dtypes, num_resource_args, arg_dtypes, node->output_types(), graph, &launch_node)); launch_node->set_assigned_device_name(node->assigned_device_name()); @@ -173,7 +173,8 @@ Status CreateXlaLaunchOp(FunctionLibraryRuntime* flr, const NodeDef& ndef, FunctionLibraryRuntime::Handle handle; // If ndef is not instantiable, e.g., the function does not exist, // simply bail out. - TF_RETURN_IF_ERROR(flr->Instantiate(ndef.op(), ndef.attr(), &handle)); + TF_RETURN_IF_ERROR( + flr->Instantiate(ndef.op(), AttrSlice(&ndef.attr()), &handle)); const FunctionBody* fbody = flr->GetFunctionBody(handle); CHECK(fbody); // Can't be nullptr since we just instantiated it. std::vector const_args(fbody->arg_types.size()); diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc index decea97fc8..88ec45f8d8 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc @@ -165,7 +165,7 @@ static const char* const kRetValOp = "_Retval"; // none. string Encapsulator::GetFunctionNameAttr(Node const* node) const { string attr; - if (!GetNodeAttr(node->def(), group_attribute_, &attr).ok()) { + if (!GetNodeAttr(node->attrs(), group_attribute_, &attr).ok()) { attr.clear(); } return attr; @@ -195,7 +195,7 @@ Status Encapsulator::SplitIntoSubgraphs() { // Check the device matches any existing device. string device = node->assigned_device_name().empty() - ? node->def().device() + ? node->requested_device() : node->assigned_device_name(); if (subgraph.device.empty()) { @@ -593,7 +593,7 @@ static Status GetArgTypes(const Graph& graph, DataTypeVector* types) { for (Node* n : graph.nodes()) { if (n->type_string() == kArgOp) { int index; - TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "index", &index)); + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index)); if (index < 0 || index >= types->size()) { return errors::InvalidArgument("Invalid argument number"); } @@ -610,7 +610,7 @@ static Status RenumberArguments(Graph* graph, for (Node* n : graph->nodes()) { if (n->type_string() == kArgOp) { int index; - TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "index", &index)); + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index)); if (index < 0 || index >= permutation.size()) { return errors::InvalidArgument("Invalid argument number"); } @@ -713,7 +713,7 @@ Status EncapsulateSubgraphsPass::Run( bool IsXlaCompiledKernel(const Node& node) { bool is_compiled = false; bool has_compilation_attr = - GetNodeAttr(node.def(), kXlaCompiledKernelAttr, &is_compiled).ok() && + GetNodeAttr(node.attrs(), kXlaCompiledKernelAttr, &is_compiled).ok() && is_compiled; return has_compilation_attr ? is_compiled : false; } diff --git a/tensorflow/compiler/jit/graph_to_functiondef.cc b/tensorflow/compiler/jit/graph_to_functiondef.cc index 88e292a2c1..83c2338500 100644 --- a/tensorflow/compiler/jit/graph_to_functiondef.cc +++ b/tensorflow/compiler/jit/graph_to_functiondef.cc @@ -126,8 +126,8 @@ Status GraphToFunctionDef(const Graph& graph, const string& name, if (node->type_string() == kArgOp) { int index; DataType type; - TF_RETURN_IF_ERROR(GetNodeAttr(node->def(), "T", &type)); - TF_RETURN_IF_ERROR(GetNodeAttr(node->def(), "index", &index)); + TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "T", &type)); + TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "index", &index)); while (fdef->signature().input_arg_size() <= index) { fdef->mutable_signature()->add_input_arg(); } @@ -143,8 +143,8 @@ Status GraphToFunctionDef(const Graph& graph, const string& name, if (node->type_string() == kRetValOp) { int index; DataType type; - TF_RETURN_IF_ERROR(GetNodeAttr(node->def(), "T", &type)); - TF_RETURN_IF_ERROR(GetNodeAttr(node->def(), "index", &index)); + TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "T", &type)); + TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "index", &index)); while (fdef->signature().output_arg_size() <= index) { fdef->mutable_signature()->add_output_arg(); } @@ -161,7 +161,7 @@ Status GraphToFunctionDef(const Graph& graph, const string& name, } NodeDef* node_def = fdef->add_node_def(); - node_def->CopyFrom(node->def()); + *node_def = node->def(); node_def->set_name(node_names.Uniquify(node->name())); // Reset input names based on graph rather than the NodeDef. @@ -203,8 +203,8 @@ Status GraphToFunctionDef(const Graph& graph, const string& name, // Populate tensor_renaming. NameRangeMap output_ranges; - TF_RETURN_IF_ERROR(NameRangesForNode(node->def(), node->op_def(), nullptr, - &output_ranges)); + TF_RETURN_IF_ERROR( + NameRangesForNode(*node, node->op_def(), nullptr, &output_ranges)); for (const auto& output : output_ranges) { for (int i = output.second.first; i < output.second.second; ++i) { const string tensor_name = strings::StrCat( diff --git a/tensorflow/compiler/jit/kernels/xla_device_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_device_launch_op.cc index cb6c14901b..29c5ff7242 100644 --- a/tensorflow/compiler/jit/kernels/xla_device_launch_op.cc +++ b/tensorflow/compiler/jit/kernels/xla_device_launch_op.cc @@ -55,7 +55,7 @@ XlaDeviceLaunchOp::XlaDeviceLaunchOp(OpKernelConstruction* ctx) OP_REQUIRES_OK(ctx, ctx->GetAttr("function", &func)); function_ = *func; VLOG(1) << "XlaDeviceLaunch created function=" - << Canonicalize(function_.name(), function_.attr()); + << Canonicalize(function_.name(), AttrSlice(&function_.attr())); DataTypeVector constant_types; OP_REQUIRES_OK(ctx, ctx->GetAttr("Tconstants", &constant_types)); num_constant_args_ = constant_types.size(); @@ -81,7 +81,7 @@ std::vector SnapshotResourceVariables(OpKernelContext* ctx, void XlaDeviceLaunchOp::Compute(OpKernelContext* ctx) { VLOG(1) << "XlaDeviceLaunch::Compute " - << Canonicalize(function_.name(), function_.attr()); + << Canonicalize(function_.name(), AttrSlice(&function_.attr())); // We store information about the JIT-compiled XLA computation // in the ResourceMgr. ResourceMgr* rm = ctx->resource_manager(); diff --git a/tensorflow/compiler/jit/kernels/xla_local_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_local_launch_op.cc index a58abcbdff..40acc0d81d 100644 --- a/tensorflow/compiler/jit/kernels/xla_local_launch_op.cc +++ b/tensorflow/compiler/jit/kernels/xla_local_launch_op.cc @@ -186,7 +186,7 @@ Status XlaLocalLaunchOp::BuildCompilationCache(OpKernelContext* ctx, void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) { VLOG(1) << "XlaLocalLaunchOp::Compute " - << Canonicalize(function_.name(), function_.attr()); + << Canonicalize(function_.name(), AttrSlice(&function_.attr())); // We store information about the JIT-compiled XLA computation // in the ResourceMgr. ResourceMgr* rm = ctx->resource_manager(); diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 711b986d7f..73c4e80551 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -56,18 +56,18 @@ bool IsCompilableCall(const NodeDef& call_def, const DeviceType& jit_device_type, int depth, FunctionLibraryRuntime* lib_runtime); -// Tests whether 'while_def' is a completely compilable loop. +// Tests whether 'while_node' is a completely compilable loop. // Every operator in the condition and body functions must be compilable for a // while loop to be compilable. -bool IsCompilableWhile(const NodeDef& while_def, +bool IsCompilableWhile(const Node& while_node, const DeviceType& jit_device_type, int depth, FunctionLibraryRuntime* lib_runtime) { - VLOG(2) << "Loop marking: " << while_def.op(); + VLOG(2) << "Loop marking: " << while_node.type_string(); const NameAttrList* name_attr; NodeDef call; Status status; - status = GetNodeAttr(while_def, "cond", &name_attr); + status = GetNodeAttr(while_node.attrs(), "cond", &name_attr); if (!status.ok()) { VLOG(2) << "Missing 'cond' attribute on While node."; return false; @@ -80,7 +80,7 @@ bool IsCompilableWhile(const NodeDef& while_def, VLOG(2) << "Can't compile loop condition: " << cond_func; return false; } - status = GetNodeAttr(while_def, "body", &name_attr); + status = GetNodeAttr(while_node.attrs(), "body", &name_attr); if (!status.ok()) { VLOG(2) << "Missing 'body' attribute on While node."; return false; @@ -112,7 +112,7 @@ bool IsCompilableCall(const NodeDef& call_def, FunctionLibraryRuntime::Handle handle; Status status = - lib_runtime->Instantiate(call_def.op(), call_def.attr(), &handle); + lib_runtime->Instantiate(call_def.op(), AttrSlice(call_def), &handle); if (!status.ok()) { VLOG(2) << "Could not instantiate " << call_def.op() << ": " << status; return false; @@ -134,11 +134,11 @@ bool IsCompilableCall(const NodeDef& call_def, for (Node* node : fbody->graph->nodes()) { if (node->IsSource() || node->IsSink()) continue; - if (node->def().op() == "_Arg" || node->def().op() == "_Retval") continue; - if (node->def().op() == "While") { + if (node->type_string() == "_Arg" || node->type_string() == "_Retval") + continue; + if (node->type_string() == "While") { // Handle functional While loop (not in open source build). - return IsCompilableWhile(node->def(), jit_device_type, depth + 1, - lib_runtime); + return IsCompilableWhile(*node, jit_device_type, depth + 1, lib_runtime); } if (!HasXLAKernel(*node, jit_device_type) && !IsCompilableCall(node->def(), jit_device_type, depth + 1, @@ -192,17 +192,16 @@ Status FindCompilationCandidates( if (!HasXLAKernel(*node, jit_device_type) && !IsCompilableCall(node->def(), jit_device_type, 0, lib_runtime.get())) { VLOG(2) << "Compilation rejected node: unsupported op " << node->name() - << ": " << node->def().op(); + << ": " << node->type_string(); continue; } if (!registration->compile_resource_ops && HasResourceArgument(*node)) { VLOG(2) << "Compilation rejected node: resource argument " << node->name() - << ": " << node->def().op(); + << ": " << node->type_string(); continue; } - if (node->def().op() == "While" && - !IsCompilableWhile(node->def(), jit_device_type, 0, - lib_runtime.get())) { + if (node->type_string() == "While" && + !IsCompilableWhile(*node, jit_device_type, 0, lib_runtime.get())) { continue; } candidates->insert(node); @@ -319,10 +318,10 @@ Status MarkForCompilationPass::Run( // If there is a _XlaCompile annotation, use its value. bool compile = false; - Status status = GetNodeAttr(node->def(), kXlaCompileAttr, &compile); + Status status = GetNodeAttr(node->attrs(), kXlaCompileAttr, &compile); if (status.ok()) return compile; - status = fld->GetAttr(node->def(), kXlaCompileAttr, &compile); + status = fld->GetAttr(*node, kXlaCompileAttr, &compile); if (status.ok()) return compile; // Otherwise use the value of global_jit_level. @@ -485,8 +484,8 @@ Status MarkForCompilationPass::RunImpl( // all nodes marked with _XlaCompile=true to also have a // _XlaScope property set (and raise an error otherwise); but // for now we don't do this. - if (GetNodeAttr(node_from->def(), kXlaScopeAttr, &from_scope).ok() && - GetNodeAttr(node_to->def(), kXlaScopeAttr, &to_scope).ok() && + if (GetNodeAttr(node_from->attrs(), kXlaScopeAttr, &from_scope).ok() && + GetNodeAttr(node_to->attrs(), kXlaScopeAttr, &to_scope).ok() && from_scope != to_scope) { continue; } @@ -541,10 +540,9 @@ Status MarkForCompilationPass::RunImpl( // Compile if the user marked this node _XlaCompile=true bool compile_attr = false; bool marked_for_compilation = false; - if (GetNodeAttr(n->def(), kXlaCompileAttr, &compile_attr).ok()) { + if (GetNodeAttr(n->attrs(), kXlaCompileAttr, &compile_attr).ok()) { marked_for_compilation = compile_attr; - } else if (options.flib_def - ->GetAttr(n->def(), kXlaCompileAttr, &compile_attr) + } else if (options.flib_def->GetAttr(*n, kXlaCompileAttr, &compile_attr) .ok()) { marked_for_compilation = compile_attr; } diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc index 91e4a2b41c..9f30e12e0e 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc @@ -57,7 +57,7 @@ std::unordered_map GetClusters(const Graph& graph) { std::unordered_map ids; for (Node* node : graph.nodes()) { string cluster; - if (GetNodeAttr(node->def(), kXlaClusterAttr, &cluster).ok()) { + if (GetNodeAttr(node->attrs(), kXlaClusterAttr, &cluster).ok()) { CHECK(!cluster.empty()); ids[node->name()] = cluster; } diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc index 82af304169..63ca77f9a9 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.cc +++ b/tensorflow/compiler/jit/xla_compilation_cache.cc @@ -95,7 +95,7 @@ Status XlaCompilationCache::BuildSignature( const NameAttrList& function, int num_constant_args, const std::vector& variable_args, OpKernelContext* ctx, Signature* signature) { - signature->name = Canonicalize(function.name(), function.attr()); + signature->name = Canonicalize(function.name(), AttrSlice(&function.attr())); signature->arg_values.resize(num_constant_args); signature->arg_types.reserve(ctx->num_inputs() - num_constant_args); diff --git a/tensorflow/compiler/tf2xla/const_analysis.cc b/tensorflow/compiler/tf2xla/const_analysis.cc index 44ff13ca34..4adc17b838 100644 --- a/tensorflow/compiler/tf2xla/const_analysis.cc +++ b/tensorflow/compiler/tf2xla/const_analysis.cc @@ -108,7 +108,7 @@ Status BackwardsConstAnalysis(const Graph& g, if (must_be_const.find(node) != must_be_const.end()) { if (node->type_string() == "_Arg") { int index; - status = GetNodeAttr(node->def(), "index", &index); + status = GetNodeAttr(node->attrs(), "index", &index); if (!status.ok()) return; compile_time_const_args->at(index) = true; return; @@ -124,8 +124,8 @@ Status BackwardsConstAnalysis(const Graph& g, if (range.first == range.second) return; NameRangeMap input_name_ranges; - status = NameRangesForNode(node->def(), node->op_def(), &input_name_ranges, - nullptr); + status = + NameRangesForNode(*node, node->op_def(), &input_name_ranges, nullptr); if (!status.ok()) return; for (auto it = range.first; it != range.second; ++it) { diff --git a/tensorflow/compiler/tf2xla/kernels/function_ops.cc b/tensorflow/compiler/tf2xla/kernels/function_ops.cc index d718f98545..8dacb6627b 100644 --- a/tensorflow/compiler/tf2xla/kernels/function_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/function_ops.cc @@ -68,7 +68,8 @@ class SymbolicGradientOp : public AsyncOpKernel { done); OP_REQUIRES_OK_ASYNC( - ctx, lib->Instantiate(kGradientOp, def().attr(), &handle_), done); + ctx, lib->Instantiate(kGradientOp, AttrSlice(&def().attr()), &handle_), + done); FunctionLibraryRuntime::Options opts; opts.step_id = ctx->step_id(); diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index a8034a2ec6..d4a917671b 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -106,7 +106,8 @@ Status XlaCompiler::CompileFunction( const XlaCompiler::CompileOptions& options, const NameAttrList& function, const std::vector& args, XlaCompiler::CompilationResult* result) { - const string function_id = Canonicalize(function.name(), function.attr()); + const string function_id = + Canonicalize(function.name(), AttrSlice(&function.attr())); VLOG(1) << "XlaCompiler::CompileFunction " << function_id; auto it = cache_.find({function_id, args}); @@ -116,8 +117,8 @@ Status XlaCompiler::CompileFunction( } FunctionLibraryRuntime::Handle handle; - TF_RETURN_IF_ERROR( - flib_runtime_->Instantiate(function.name(), function.attr(), &handle)); + TF_RETURN_IF_ERROR(flib_runtime_->Instantiate( + function.name(), AttrSlice(&function.attr()), &handle)); const FunctionBody* fbody = flib_runtime_->GetFunctionBody(handle); CHECK(fbody); diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 6056d50587..0006aaa0b5 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -1408,6 +1408,11 @@ tf_cuda_library( "framework/**/*.cc", "util/**/*.h", "util/**/*.cc", + ] + [ + "graph/edgeset.h", + "graph/edgeset.cc", + "graph/graph.h", + "graph/graph.cc", ], exclude = [ "**/*test*", @@ -1548,8 +1553,6 @@ tf_cuda_library( "graph/colors.cc", "graph/control_flow.cc", "graph/costmodel.cc", - "graph/edgeset.cc", - "graph/graph.cc", "graph/graph_constructor.cc", "graph/graph_def_builder.cc", "graph/graph_partition.cc", diff --git a/tensorflow/core/common_runtime/constant_folding_test.cc b/tensorflow/core/common_runtime/constant_folding_test.cc index 9cc761a771..4a8560960e 100644 --- a/tensorflow/core/common_runtime/constant_folding_test.cc +++ b/tensorflow/core/common_runtime/constant_folding_test.cc @@ -48,9 +48,9 @@ class ConstantFoldingTest : public ::testing::Test { TensorShape shape) { EXPECT_TRUE(n->IsConstant()); const TensorProto* tensor_proto; - TF_EXPECT_OK(GetNodeAttr(n->def(), "value", &tensor_proto)); + TF_EXPECT_OK(GetNodeAttr(n->attrs(), "value", &tensor_proto)); DataType dtype; - TF_EXPECT_OK(GetNodeAttr(n->def(), "dtype", &dtype)); + TF_EXPECT_OK(GetNodeAttr(n->attrs(), "dtype", &dtype)); Tensor t(dtype); EXPECT_TRUE(t.FromProto(*tensor_proto)); test::ExpectClose(t, test::AsTensor(values, shape)); @@ -61,9 +61,9 @@ class ConstantFoldingTest : public ::testing::Test { TensorShape shape) { EXPECT_TRUE(n->IsConstant()); const TensorProto* tensor_proto; - TF_EXPECT_OK(GetNodeAttr(n->def(), "value", &tensor_proto)); + TF_EXPECT_OK(GetNodeAttr(n->attrs(), "value", &tensor_proto)); DataType dtype; - TF_EXPECT_OK(GetNodeAttr(n->def(), "dtype", &dtype)); + TF_EXPECT_OK(GetNodeAttr(n->attrs(), "dtype", &dtype)); Tensor t(dtype); EXPECT_TRUE(t.FromProto(*tensor_proto)); test::ExpectTensorEqual(t, test::AsTensor(values, shape)); diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc index ec0c9405dd..9e18547af5 100644 --- a/tensorflow/core/common_runtime/executor.cc +++ b/tensorflow/core/common_runtime/executor.cc @@ -92,31 +92,28 @@ bool SetTimelineLabel(const Node* node, NodeExecStats* node_stats) { } } } - const NodeDef& def = node->def(); - string text = ""; + const AttrSlice attrs = node->attrs(); + string text; if (IsSend(node)) { string tensor_name; - TF_CHECK_OK(GetNodeAttr(def, "tensor_name", &tensor_name)); + TF_CHECK_OK(GetNodeAttr(attrs, "tensor_name", &tensor_name)); string recv_device; - TF_CHECK_OK(GetNodeAttr(def, "recv_device", &recv_device)); - text = strings::StrCat(memory, def.name(), " = ", def.op(), "(", - tensor_name, " @", recv_device); + TF_CHECK_OK(GetNodeAttr(attrs, "recv_device", &recv_device)); + text = strings::StrCat(memory, node->name(), " = ", node->type_string(), + "(", tensor_name, " @", recv_device); is_transfer_node = true; } else if (IsRecv(node)) { string tensor_name; - TF_CHECK_OK(GetNodeAttr(def, "tensor_name", &tensor_name)); + TF_CHECK_OK(GetNodeAttr(attrs, "tensor_name", &tensor_name)); string send_device; - TF_CHECK_OK(GetNodeAttr(def, "send_device", &send_device)); - text = strings::StrCat(memory, def.name(), " = ", def.op(), "(", - tensor_name, " @", send_device); + TF_CHECK_OK(GetNodeAttr(attrs, "send_device", &send_device)); + text = strings::StrCat(memory, node->name(), " = ", node->type_string(), + "(", tensor_name, " @", send_device); is_transfer_node = true; } else { - text = strings::StrCat( - memory, def.name(), " = ", def.op(), "(", - str_util::Join( - std::vector(def.input().begin(), def.input().end()), - ", "), - ")"); + text = + strings::StrCat(memory, node->name(), " = ", node->type_string(), "(", + str_util::Join(node->requested_inputs(), ", "), ")"); } node_stats->set_timeline_label(text); return is_transfer_node; @@ -522,7 +519,7 @@ char* GraphView::InitializeNode(char* ptr, const Node* n) { EdgeInfo* dst_edge = item->output_edge_base(); for (auto e : n->out_edges()) { dst_edge->dst_id = e->dst()->id(); - CHECK_LE(e->src_output(), ((int32)0x3FFFFFFF)); // Must fit in 31 bits + CHECK_LE(e->src_output(), 0x3FFFFFFF); // Must fit in 31 bits dst_edge->output_slot = e->src_output(); dst_edge->is_last = false; const int output_slot = dst_edge->output_slot; @@ -640,7 +637,7 @@ Status ExecutorImpl::Initialize() { Status s = params_.create_kernel(n->def(), &item->kernel); if (!s.ok()) { item->kernel = nullptr; - s = AttachDef(s, n->def()); + s = AttachDef(s, *n); LOG(ERROR) << "Executor failed to create kernel. " << s; return s; } @@ -668,7 +665,7 @@ Status ExecutorImpl::Initialize() { frame_info->nodes->push_back(n); if (IsEnter(n)) { string enter_name; - TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "frame_name", &enter_name)); + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "frame_name", &enter_name)); EnsureFrameInfo(enter_name)->input_count++; } } @@ -723,7 +720,7 @@ Status InferAllocAttr(const Node* n, const Node* dst, // so these two cases are not mutually exclusive. if (IsRecv(n)) { string src_name; - s = GetNodeAttr(n->def(), "send_device", &src_name); + s = GetNodeAttr(n->attrs(), "send_device", &src_name); if (!s.ok()) return s; DeviceNameUtils::ParsedName parsed_src_name; if (!DeviceNameUtils::ParseFullName(src_name, &parsed_src_name)) { @@ -748,7 +745,7 @@ Status InferAllocAttr(const Node* n, const Node* dst, } if (IsSend(dst)) { string dst_name; - s = GetNodeAttr(dst->def(), "recv_device", &dst_name); + s = GetNodeAttr(dst->attrs(), "recv_device", &dst_name); if (!s.ok()) return s; DeviceNameUtils::ParsedName parsed_dst_name; if (!DeviceNameUtils::ParseFullName(dst_name, &parsed_dst_name)) { @@ -1361,7 +1358,7 @@ Status ExecutorImpl::BuildControlFlowInfo(const Graph* g, if (IsEnter(curr_node)) { // Enter a child frame. TF_RETURN_IF_ERROR( - GetNodeAttr(curr_node->def(), "frame_name", &frame_name)); + GetNodeAttr(curr_node->attrs(), "frame_name", &frame_name)); parent = curr_node; } else if (IsExit(curr_node)) { // Exit to the parent frame. @@ -1555,8 +1552,7 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_usec) { if (vlog_) { VLOG(1) << "Process node: " << id << " step " << params.step_id << " " - << SummarizeNodeDef(node->def()) - << " is dead: " << tagged_node.is_dead; + << SummarizeNode(*node) << " is dead: " << tagged_node.is_dead; } Entry* input_tensors = GetInputTensors(input_frame, input_iter); @@ -1610,7 +1606,7 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_usec) { if (vlog_) { VLOG(2) << this << " Async kernel done: " - << SummarizeNodeDef(state->item->node->def()); + << SummarizeNode(*state->item->node); } if (stats) nodestats::SetOpEnd(stats); EntryVector outputs; @@ -1811,7 +1807,7 @@ Status ExecutorState::ProcessOutputs(const NodeItem& item, OpKernelContext* ctx, // tensor value at i-th output. if (!IsSwitch(node) && !IsRecv(node)) { s.Update(errors::Internal("Missing ", i, "-th output from ", - SummarizeNodeDef(node->def()))); + SummarizeNode(*node))); } } else { Entry* out = &((*outputs)[i]); @@ -1878,7 +1874,7 @@ Status ExecutorState::ProcessOutputs(const NodeItem& item, OpKernelContext* ctx, DataTypeString(dtype), " does not match declared output type ", DataTypeString(item.output_type(i)), - " for node ", SummarizeNodeDef(node->def()))); + " for node ", SummarizeNode(*node))); } } if (!val.is_ref()) { @@ -1915,7 +1911,7 @@ void ExecutorState::PropagateOutputs(const TaggedNode& tagged_node, &impl_->gview_, input_iter, ready); } else if (item->is_enter) { bool is_constant; - Status s = GetNodeAttr(node->def(), "is_constant", &is_constant); + Status s = GetNodeAttr(node->attrs(), "is_constant", &is_constant); DCHECK(s.ok()) << s; FindOrCreateChildFrame(input_frame, input_iter, node, &output_frame); output_iter = 0; @@ -2241,7 +2237,7 @@ void ExecutorState::FindOrCreateChildFrame(FrameState* frame, int64 iter, FrameState** child) { // Get the child frame name. string enter_name; - Status s = GetNodeAttr(node->def(), "frame_name", &enter_name); + Status s = GetNodeAttr(node->attrs(), "frame_name", &enter_name); DCHECK(s.ok()) << s; const string child_name = MakeFrameName(frame, iter, enter_name); @@ -2259,7 +2255,7 @@ void ExecutorState::FindOrCreateChildFrame(FrameState* frame, int64 iter, if (vlog_) VLOG(2) << "Create frame: " << child_name; int parallel_iters; - s = GetNodeAttr(node->def(), "parallel_iterations", ¶llel_iters); + s = GetNodeAttr(node->attrs(), "parallel_iterations", ¶llel_iters); DCHECK(s.ok()) << s; FrameState* temp = new FrameState(impl_, parallel_iters); temp->frame_name = child_name; diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc index 4be22f8260..996a8a9b3d 100644 --- a/tensorflow/core/common_runtime/function.cc +++ b/tensorflow/core/common_runtime/function.cc @@ -150,8 +150,7 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime { ~FunctionLibraryRuntimeImpl() override; - Status Instantiate(const string& function_name, - const InstantiateAttrValueMap& attrs, + Status Instantiate(const string& function_name, AttrSlice attrs, Handle* handle) override; const FunctionBody* GetFunctionBody(Handle handle) override; @@ -208,8 +207,7 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime { }; std::vector items_; - Status FunctionDefToBody(const FunctionDef& fdef, - const InstantiateAttrValueMap& attrs, + Status FunctionDefToBody(const FunctionDef& fdef, AttrSlice attrs, FunctionBody** fbody); Status CreateItem(Handle handle, Item** item); Status GetOrCreateItem(Handle handle, Item** item); @@ -324,7 +322,7 @@ Status FunctionLibraryRuntimeImpl::CreateKernel(const NodeDef& ndef, // Try to instantiate this function for the func/attr. Maybe its // cached already. Handle handle; - TF_RETURN_IF_ERROR(Instantiate(ndef.op(), ndef.attr(), &handle)); + TF_RETURN_IF_ERROR(Instantiate(ndef.op(), AttrSlice(&ndef.attr()), &handle)); const FunctionBody* fbody = GetFunctionBody(handle); CHECK_NOTNULL(fbody); @@ -355,9 +353,9 @@ Status FunctionLibraryRuntimeImpl::CreateKernel(const NodeDef& ndef, return s; } -Status FunctionLibraryRuntimeImpl::FunctionDefToBody( - const FunctionDef& fdef, const InstantiateAttrValueMap& attrs, - FunctionBody** fbody) { +Status FunctionLibraryRuntimeImpl::FunctionDefToBody(const FunctionDef& fdef, + AttrSlice attrs, + FunctionBody** fbody) { // Instantiates the function template into a graph def. InstantiationResult result; TF_RETURN_IF_ERROR(InstantiateFunction(fdef, attrs, get_func_sig_, &result)); @@ -390,11 +388,13 @@ Status FunctionLibraryRuntimeImpl::InstantiateSymbolicGradient( // TODO(josh11b): Should filter out the attrs from func that aren't used // by the gradient function. TF_RETURN_IF_ERROR(creator(AttrSlice(&func.attr()), &grad_fdef)); - TF_RETURN_IF_ERROR(FunctionDefToBody(grad_fdef, func.attr(), g_body)); + TF_RETURN_IF_ERROR( + FunctionDefToBody(grad_fdef, AttrSlice(&func.attr()), g_body)); } else { // f is a user-defined function. Handle f_handle; - TF_RETURN_IF_ERROR(Instantiate(func.name(), func.attr(), &f_handle)); + TF_RETURN_IF_ERROR( + Instantiate(func.name(), AttrSlice(&func.attr()), &f_handle)); const FunctionBody* f_body = GetFunctionBody(f_handle); CHECK_NOTNULL(f_body); *g_body = SymbolicGradient(*f_body); @@ -402,9 +402,9 @@ Status FunctionLibraryRuntimeImpl::InstantiateSymbolicGradient( return Status::OK(); } -Status FunctionLibraryRuntimeImpl::Instantiate( - const string& function_name, const InstantiateAttrValueMap& attrs, - Handle* handle) { +Status FunctionLibraryRuntimeImpl::Instantiate(const string& function_name, + AttrSlice attrs, + Handle* handle) { const string key = Canonicalize(function_name, attrs); { mutex_lock l(mu_); @@ -417,7 +417,7 @@ Status FunctionLibraryRuntimeImpl::Instantiate( Status s; FunctionBody* fbody = nullptr; if (function_name == kGradientOp) { - const AttrValue* f = gtl::FindOrNull(attrs, kFuncAttr); + const AttrValue* f = attrs.Find(kFuncAttr); if (f == nullptr) { return errors::InvalidArgument("SymbolicGradient is missing attr: f"); } @@ -427,7 +427,7 @@ Status FunctionLibraryRuntimeImpl::Instantiate( } const string grad = lib_def_->FindGradient(func.name()); if (!grad.empty()) { - return Instantiate(grad, func.attr(), handle); + return Instantiate(grad, AttrSlice(&func.attr()), handle); } TF_RETURN_IF_ERROR(InstantiateSymbolicGradient(func, &fbody)); } else { @@ -989,13 +989,12 @@ bool ExpandInlineFunctions(FunctionLibraryRuntime* lib, Graph* graph) { for (Node* node : graph->nodes()) { VLOG(3) << "Expanding " << node->DebugString(); bool noinline; - if (fld->GetAttr(node->def(), kNoInlineAttr, &noinline).ok() && noinline) { + if (fld->GetAttr(*node, kNoInlineAttr, &noinline).ok() && noinline) { VLOG(3) << "noinline: " << node->DebugString(); continue; } FunctionLibraryRuntime::Handle handle; - Status s = - lib->Instantiate(node->type_string(), node->def().attr(), &handle); + Status s = lib->Instantiate(node->type_string(), node->attrs(), &handle); if (!s.ok()) { // Either "node" is a primitive op, or the instantiation failed. if (errors::IsNotFound(s)) { @@ -1103,7 +1102,7 @@ FunctionBody::FunctionBody(const FunctionDef& f, DataTypeSlice arg_t, continue; } int index; - TF_CHECK_OK(GetNodeAttr(n->def(), "index", &index)); + TF_CHECK_OK(GetNodeAttr(n->attrs(), "index", &index)); CHECK_LE(0, index); CHECK_LT(index, node_vec->size()); (*node_vec)[index] = n; diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc index dfa1ed8a7e..e27fc3898d 100644 --- a/tensorflow/core/common_runtime/function_test.cc +++ b/tensorflow/core/common_runtime/function_test.cc @@ -40,6 +40,7 @@ limitations under the License. #include "tensorflow/core/util/equal_graph_def.h" namespace tensorflow { +namespace { typedef FunctionDefHelper FDH; @@ -58,13 +59,29 @@ void HasError(const Status& s, const string& substr) { << s << ", expected substring " << substr; } +// A helper class to make AttrSlice from initializer lists +class Attrs { + public: + Attrs(const std::initializer_list< // NOLINT(runtime/explicit) + std::pair>& attrs) { + for (const auto& aval : attrs) { + map_.insert({aval.first, aval.second.proto}); + } + } + + operator AttrSlice() { return AttrSlice(&map_); } // NOLINT(runtime/explicit) + + private: + AttrValueMap map_; +}; + class FunctionTest : public ::testing::Test { protected: FunctionTest() : device_(DeviceFactory::NewDevice("CPU", {}, "/job:localhost/replica:0/task:0")) {} - void Create(const FunctionDef& fdef, InstantiateAttrValueSlice attrs) { + void Create(const FunctionDef& fdef, Attrs attrs) { exec_ = nullptr; InstantiationResult result; TF_CHECK_OK(InstantiateFunction(fdef, attrs, GetOpSig, &result)); @@ -151,8 +168,8 @@ class FunctionLibraryRuntimeTest : public ::testing::Test { fdef_lib_ = lib_def_->ToProto(); } - Status Run(const string& name, InstantiateAttrValueSlice attrs, - const std::vector& args, std::vector rets) { + Status Run(const string& name, Attrs attrs, const std::vector& args, + std::vector rets) { FunctionLibraryRuntime::Handle handle; Status status = lib_->Instantiate(name, attrs, &handle); if (!status.ok()) { @@ -188,8 +205,7 @@ class FunctionLibraryRuntimeTest : public ::testing::Test { return Status::OK(); } - std::unique_ptr GetFuncBody(const string& name, - InstantiateAttrValueSlice attrs) { + std::unique_ptr GetFuncBody(const string& name, Attrs attrs) { FunctionLibraryRuntime::Handle handle; Status status = lib_->Instantiate(name, attrs, &handle); if (!status.ok()) { @@ -203,8 +219,7 @@ class FunctionLibraryRuntimeTest : public ::testing::Test { return ret; } - std::unique_ptr GetGradBody(const string& func, - InstantiateAttrValueSlice attrs) { + std::unique_ptr GetGradBody(const string& func, Attrs attrs) { FunctionLibraryRuntime::Handle handle; Status status = lib_->Instantiate(func, attrs, &handle); if (!status.ok()) { @@ -615,13 +630,14 @@ TEST_F(FunctionLibraryRuntimeTest, Error_InstantiaionError) { // Instantiating "XTimesTwo" should fail. FunctionLibraryRuntime::Handle handle; - HasError(lib_->Instantiate("XTimesTwo", {{"T", DT_FLOAT}}, &handle), + HasError(lib_->Instantiate("XTimesTwo", Attrs({{"T", DT_FLOAT}}), &handle), "Not found: type attr not found"); // But XTimesFour and XTimes16 instantiation should succeed. Only // when they run, they fail because XTimesTwo is bad. - TF_CHECK_OK(lib_->Instantiate("XTimesFour", {{"T", DT_FLOAT}}, &handle)); - TF_CHECK_OK(lib_->Instantiate("XTimes16", {{"T", DT_FLOAT}}, &handle)); + TF_CHECK_OK( + lib_->Instantiate("XTimesFour", Attrs({{"T", DT_FLOAT}}), &handle)); + TF_CHECK_OK(lib_->Instantiate("XTimes16", Attrs({{"T", DT_FLOAT}}), &handle)); auto x = test::AsTensor({1, 2, 3, 4}); Tensor y; @@ -928,8 +944,7 @@ bool DoNothing(Graph* g) { return false; } GraphDef Optimize(const std::function& pass, const FunctionDef& fdef) { InstantiationResult result; - InstantiateAttrValueMap empty; - TF_CHECK_OK(InstantiateFunction(fdef, empty, GetOpSig, &result)); + TF_CHECK_OK(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result)); std::unique_ptr g(new Graph(OpRegistry::Global())); GraphConstructorOptions opts; opts.allow_internal_ops = true; @@ -1248,4 +1263,5 @@ TEST(OptimizationTest, RemoveListArrayConverter_WithContolDeps) { TF_EXPECT_GRAPH_EQ(expected, Optimize(remove_listarray_and_identity, func)); } +} // end namespace } // end namespace tensorflow diff --git a/tensorflow/core/common_runtime/kernel_benchmark_testlib.cc b/tensorflow/core/common_runtime/kernel_benchmark_testlib.cc index 72bc37d435..4e14e6fe1a 100644 --- a/tensorflow/core/common_runtime/kernel_benchmark_testlib.cc +++ b/tensorflow/core/common_runtime/kernel_benchmark_testlib.cc @@ -103,13 +103,13 @@ void Benchmark::Run(int iters) { RunWithArgs({}, {}, iters); } string GetRendezvousKey(const Node* node) { string send_device; - TF_CHECK_OK(GetNodeAttr(node->def(), "send_device", &send_device)); + TF_CHECK_OK(GetNodeAttr(node->attrs(), "send_device", &send_device)); string recv_device; - TF_CHECK_OK(GetNodeAttr(node->def(), "recv_device", &recv_device)); + TF_CHECK_OK(GetNodeAttr(node->attrs(), "recv_device", &recv_device)); string tensor_name; - TF_CHECK_OK(GetNodeAttr(node->def(), "tensor_name", &tensor_name)); + TF_CHECK_OK(GetNodeAttr(node->attrs(), "tensor_name", &tensor_name)); uint64 send_device_incarnation; - TF_CHECK_OK(GetNodeAttr(node->def(), "send_device_incarnation", + TF_CHECK_OK(GetNodeAttr(node->attrs(), "send_device_incarnation", reinterpret_cast(&send_device_incarnation))); return Rendezvous::CreateKey(send_device, send_device_incarnation, recv_device, tensor_name, FrameAndIter(0, 0)); diff --git a/tensorflow/core/common_runtime/parallel_concat_optimizer.cc b/tensorflow/core/common_runtime/parallel_concat_optimizer.cc index ffbfbc74f1..bbd38a2e07 100644 --- a/tensorflow/core/common_runtime/parallel_concat_optimizer.cc +++ b/tensorflow/core/common_runtime/parallel_concat_optimizer.cc @@ -49,11 +49,11 @@ class ParallelConcatRemovePass : public GraphOptimizationPass { } } for (Node* n : matches) { - AttrSlice n_attrs(n->def()); + AttrSlice n_attrs = n->attrs(); auto base_make_node = [n, g, &n_attrs](const string& op, const string& name) { NodeBuilder node_builder(name, op); - node_builder.Device(n->def().device()); + node_builder.Device(n->requested_device()); string colo; if (GetNodeAttr(n_attrs, "_class", &colo).ok()) { node_builder.Attr("_class", colo); diff --git a/tensorflow/core/common_runtime/resource_variable_read_optimizer.cc b/tensorflow/core/common_runtime/resource_variable_read_optimizer.cc index c179e94c36..b40924ef3a 100644 --- a/tensorflow/core/common_runtime/resource_variable_read_optimizer.cc +++ b/tensorflow/core/common_runtime/resource_variable_read_optimizer.cc @@ -55,7 +55,7 @@ class ResourceVariableReadPass : public GraphOptimizationPass { } for (Node* read : matches) { DataType dtype; - TF_RETURN_IF_ERROR(GetNodeAttr(AttrSlice(read->def()), "dtype", &dtype)); + TF_RETURN_IF_ERROR(GetNodeAttr(read->attrs(), "dtype", &dtype)); std::vector in_control_edges; std::vector> in_edges; for (const Edge* edge : read->in_edges()) { diff --git a/tensorflow/core/common_runtime/simple_placer.cc b/tensorflow/core/common_runtime/simple_placer.cc index 8808d52d3e..ae225e8b35 100644 --- a/tensorflow/core/common_runtime/simple_placer.cc +++ b/tensorflow/core/common_runtime/simple_placer.cc @@ -76,7 +76,7 @@ void ColocationGroups(const Node& node, std::vector class_specs; // TODO(vrv): We should consider adding a GetNodeAttr that returns a // StringPiece, to avoid a copy. - if (!GetNodeAttrSimple(node.def(), kColocationAttrNameStringPiece, + if (!GetNodeAttrSimple(node.attrs(), kColocationAttrNameStringPiece, &class_specs)) { // No attribute value is equivalent to the empty colocation_group. *colocation_groups = { @@ -329,7 +329,7 @@ class ColocationGraph { AddDebugInfo(node_root, &debug_info); DeviceNameUtils::ParsedName specified_device_name; - if (DeviceNameUtils::ParseFullName(node->def().device(), + if (DeviceNameUtils::ParseFullName(node->requested_device(), &specified_device_name) && specified_device_name == members_[node_root].device_name) { // The specified device and merged set device match, and @@ -348,27 +348,27 @@ class ColocationGraph { std::sort(device_names.begin(), device_names.end()); return errors::InvalidArgument( - "Operation was explicitly assigned to ", node->def().device(), - " but available devices are [ ", + "Operation was explicitly assigned to ", + node->requested_device(), " but available devices are [ ", str_util::Join(device_names, ", "), " ]. Make sure ", "the device specification refers to a valid device."); } else if (specified_device_name.has_type) { return errors::InvalidArgument( "Could not satisfy explicit device specification '", - node->def().device(), "' because no supported kernel for ", + node->requested_device(), "' because no supported kernel for ", specified_device_name.type, " devices is available.", debug_info); } else { return errors::InvalidArgument( "Could not satisfy explicit device specification '", - node->def().device(), debug_info); + node->requested_device(), debug_info); } } else { // The specified device may be a valid device but the // merged set device is different, so print both. return errors::InvalidArgument( "Could not satisfy explicit device specification '", - node->def().device(), + node->requested_device(), "' because the node was colocated with a group of nodes that " "required incompatible device '", DeviceNameUtils::ParsedNameToString( @@ -513,7 +513,7 @@ class ColocationGraph { return errors::Internal("Assigned device '", node.assigned_device_name(), "' does not have registered OpKernel support " "for ", - node.def().op()); + node.type_string()); } else { // This node has not yet been assigned to a device, so we // calculate any constraints due to the set of registered @@ -527,25 +527,25 @@ class ColocationGraph { registered_device_types.insert(d->device_type()); } return errors::InvalidArgument( - "No OpKernel was registered to support Op '", node.def().op(), + "No OpKernel was registered to support Op '", node.type_string(), "' with these attrs. Registered devices: [", str_util::Join(registered_device_types, ","), "], Registered kernels:\n", - KernelsRegisteredForOp(node.def().op())); + KernelsRegisteredForOp(node.type_string())); } // If the NodeDef contains a device, then we interpret it as a // (partial) device specification. - if (!node.def().device().empty()) { + if (!node.requested_device().empty()) { // The user has specified a device in the NodeDef, try to find a // valid device matching their specification in the set of // devices. // NOTE: The full name may specify a device that is not in // n.supported_device_types(), but we check that in AssignDevice(). - if (!DeviceNameUtils::ParseFullName(node.def().device(), + if (!DeviceNameUtils::ParseFullName(node.requested_device(), &member->device_name)) { return errors::InvalidArgument("Malformed device specification '", - node.def().device(), "'"); + node.requested_device(), "'"); } } } @@ -644,7 +644,7 @@ Status SimplePlacer::Run() { continue; } status = colocation_graph.AddNode(*node); - if (!status.ok()) return AttachDef(status, node->def()); + if (!status.ok()) return AttachDef(status, *node); } // 2. Enumerate the constraint edges, and use them to update the disjoint @@ -707,7 +707,7 @@ Status SimplePlacer::Run() { "be on the same device), but the two nodes " "were assigned two different devices: ", status.error_message()), - node->def()); + *node); } } } @@ -749,7 +749,7 @@ Status SimplePlacer::Run() { return AttachDef( errors::InvalidArgument("Cannot assign a device for operation '", node->name(), "': ", status.error_message()), - node->def()); + *node); } // Returns the first device in sorted devices list so we will always @@ -791,7 +791,7 @@ Status SimplePlacer::Run() { return AttachDef( errors::InvalidArgument("Cannot assign a device for operation '", node->name(), "': ", status.error_message()), - node->def()); + *node); } string assigned_device = devices[0]->name(); diff --git a/tensorflow/core/debug/debug_graph_utils.cc b/tensorflow/core/debug/debug_graph_utils.cc index a222dc75d7..f8f3d2ae50 100644 --- a/tensorflow/core/debug/debug_graph_utils.cc +++ b/tensorflow/core/debug/debug_graph_utils.cc @@ -223,19 +223,16 @@ Status DebugNodeInserter::InsertNodes( void DebugNodeInserter::DeparallelizeWhileLoops(Graph* graph, Device* device) { for (Node* node : graph->nodes()) { if (node->IsEnter()) { - for (const auto& attr : node->def().attr()) { - if (attr.first == "parallel_iterations") { - if (attr.second.i() > 1) { - LOG(INFO) << "For debugging, tfdbg is changing the " - << "parallel_iterations attribute of the Enter/RefEnter " - << "node \"" << node->name() << "\" on device \"" - << device->name() << "\" from " << attr.second.i() - << " to 1. (This does not affect subsequent non-debug " - << "runs.)"; - node->AddAttr("parallel_iterations", 1); - } - break; - } + const AttrValue* parallel_iterations = + node->attrs().Find("parallel_iterations"); + if (parallel_iterations && parallel_iterations->i() > 1) { + LOG(INFO) << "For debugging, tfdbg is changing the " + << "parallel_iterations attribute of the Enter/RefEnter " + << "node \"" << node->name() << "\" on device \"" + << device->name() << "\" from " << parallel_iterations->i() + << " to 1. (This does not affect subsequent non-debug " + << "runs.)"; + node->AddAttr("parallel_iterations", 1); } } } diff --git a/tensorflow/core/distributed_runtime/master_session.cc b/tensorflow/core/distributed_runtime/master_session.cc index 73758ade03..dddff4dce4 100644 --- a/tensorflow/core/distributed_runtime/master_session.cc +++ b/tensorflow/core/distributed_runtime/master_session.cc @@ -188,7 +188,7 @@ class MasterSession::ReffedClientGraph : public core::RefCounted { const RunState* run_state, SimpleGraphExecutionState* execution_state); - string DetailText(const NodeDef& def, const NodeExecStats& ns) { + string DetailText(const Node& node, const NodeExecStats& ns) { int64 tot = 0; for (auto& no : ns.output()) { tot += no.tensor_description().allocation_description().requested_bytes(); @@ -197,12 +197,8 @@ class MasterSession::ReffedClientGraph : public core::RefCounted { if (tot >= 0.1 * 1048576.0) { bytes = strings::Printf("[%.1fMB] ", tot / 1048576.0); } - return strings::StrCat( - bytes, def.name(), " = ", def.op(), "(", - str_util::Join( - std::vector(def.input().begin(), def.input().end()), - ", "), - ")"); + return strings::StrCat(bytes, node.name(), " = ", node.type_string(), "(", + str_util::Join(node.requested_inputs(), ", "), ")"); } private: @@ -790,7 +786,7 @@ void MasterSession::ReffedClientGraph::ProcessDeviceStats( if (!ns.timeline_label().empty()) { details = ns.timeline_label(); } else if (found_node_in_graph) { - details = DetailText(node->def(), ns); + details = DetailText(*node, ns); } else { // Leave details string empty } diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc index f46bb6e2ed..186095201d 100644 --- a/tensorflow/core/framework/function.cc +++ b/tensorflow/core/framework/function.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/core/framework/function.pb_text.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op.h" +#include "tensorflow/core/graph/graph.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/inlined_vector.h" #include "tensorflow/core/lib/gtl/map_util.h" @@ -44,12 +45,11 @@ namespace { // Otherwise (arg_def is a simple type T), *is_type_list is set to // false, and *dtypes is set to a single element vector, whose only // element is T. -Status ArgNumType(const InstantiateAttrValueMap& attrs, - const OpDef::ArgDef& arg_def, bool* is_type_list, - DataTypeVector* dtypes) { +Status ArgNumType(AttrSlice attrs, const OpDef::ArgDef& arg_def, + bool* is_type_list, DataTypeVector* dtypes) { dtypes->clear(); if (!arg_def.type_list_attr().empty()) { - const AttrValue* v = gtl::FindOrNull(attrs, arg_def.type_list_attr()); + const AttrValue* v = attrs.Find(arg_def.type_list_attr()); if (v == nullptr) { return errors::NotFound("type attr not found: ", arg_def.type_list_attr()); @@ -64,7 +64,7 @@ Status ArgNumType(const InstantiateAttrValueMap& attrs, *is_type_list = false; int num = 1; if (!arg_def.number_attr().empty()) { - const AttrValue* v = gtl::FindOrNull(attrs, arg_def.number_attr()); + const AttrValue* v = attrs.Find(arg_def.number_attr()); if (v == nullptr) { return errors::NotFound("type attr not found: ", arg_def.type_attr()); } @@ -77,7 +77,7 @@ Status ArgNumType(const InstantiateAttrValueMap& attrs, } else if (arg_def.type_attr().empty()) { dtype = DT_INVALID; } else { - const AttrValue* v = gtl::FindOrNull(attrs, arg_def.type_attr()); + const AttrValue* v = attrs.Find(arg_def.type_attr()); if (v == nullptr) { return errors::NotFound("type attr not found: ", arg_def.type_attr()); } @@ -92,18 +92,17 @@ void AddAttr(const string& name, const T& val, NodeDef* ndef) { SetAttrValue(val, &((*ndef->mutable_attr())[name])); } -Status ValidateSignatureWithAttrs(const OpDef& sig, - const InstantiateAttrValueMap& attr_values) { +Status ValidateSignatureWithAttrs(const OpDef& sig, AttrSlice attr_values) { // attr_values should specify all attrs defined in fdef. for (const auto& a : sig.attr()) { - auto const iter = attr_values.find(a.name()); - if (iter == attr_values.end()) { + const AttrValue* v = attr_values.Find(a.name()); + if (!v) { return errors::NotFound("Attr ", a.name(), " is not found from ", SummarizeOpDef(sig)); } - Status status = AttrValueHasType(iter->second, a.type()); + Status status = AttrValueHasType(*v, a.type()); if (!status.ok()) { - errors::AppendToMessage(&status, "for attr '", iter->first, "'"); + errors::AppendToMessage(&status, "for attr '", a.name(), "'"); return status; } } @@ -146,7 +145,7 @@ class FunctionInstantiationHelper { // Builds index for nodes that can be used as node's input arguments. Status BuildInputArgIndex(const OpDef::ArgDef& arg_def, - const InstantiateAttrValueMap& attr_values) { + AttrSlice attr_values) { bool is_type_list; DataTypeVector dtypes; TF_RETURN_IF_ERROR( @@ -175,8 +174,7 @@ class FunctionInstantiationHelper { return Status::OK(); } - Status BuildNodeOutputIndex(const NodeDef& node, - const InstantiateAttrValueMap& attrs, + Status BuildNodeOutputIndex(const NodeDef& node, AttrSlice attrs, const int arg_index) { const OpDef* node_sig = nullptr; TF_RETURN_IF_ERROR(get_function_(node.op(), &node_sig)); @@ -206,8 +204,7 @@ class FunctionInstantiationHelper { return Status::OK(); } - Status InstantiateNode(const NodeDef& fnode, - const InstantiateAttrValueMap& attrs) { + Status InstantiateNode(const NodeDef& fnode, AttrSlice attrs) { const OpDef* fnode_sig = nullptr; TF_CHECK_OK(get_function_(fnode.op(), &fnode_sig)); NodeDef* gnode = AddNode(fnode.name()); @@ -295,7 +292,7 @@ class FunctionInstantiationHelper { } Status AddReturnNode( - const OpDef::ArgDef& ret_def, const InstantiateAttrValueMap& attrs, + const OpDef::ArgDef& ret_def, AttrSlice attrs, const ::tensorflow::protobuf::Map& ret_map, int* ret_index) { auto ret_iter = ret_map.find(ret_def.name()); @@ -604,7 +601,7 @@ string Print(const GraphDef& gdef) { Status AddDefaultAttrs(const string& op, const GetFunctionSignature& get_function, - InstantiateAttrValueMap* attrs) { + AttrValueMap* attrs) { const OpDef* op_def = nullptr; TF_RETURN_IF_ERROR(get_function(op, &op_def)); AttrSlice attr_slice(attrs); @@ -620,8 +617,7 @@ Status AddDefaultAttrs(const string& op, } // end namespace -Status InstantiateFunction(const FunctionDef& fdef, - const InstantiateAttrValueMap& attr_values, +Status InstantiateFunction(const FunctionDef& fdef, AttrSlice attr_values, GetFunctionSignature get_function, InstantiationResult* result) { VLOG(3) << "Instantiation Function: " << Print(fdef); @@ -639,19 +635,17 @@ Status InstantiateFunction(const FunctionDef& fdef, } } - auto substitute = [&attr_values](const string& name, AttrValue* val) { - auto iter = attr_values.find(name); - if (iter == attr_values.end()) { - return false; - } else { - *val = iter->second; + auto substitute = [attr_values](StringPiece name, AttrValue* val) { + if (const AttrValue* v = attr_values.Find(name)) { + *val = *v; return true; } + return false; }; // Makes a copy of all attrs in fdef and substitutes placeholders. // After this step, every attr is bound to a concrete value. - std::vector node_attrs; + std::vector node_attrs; node_attrs.resize(fdef.node_def_size()); for (int i = 0; i < fdef.node_def_size(); ++i) { for (auto attr : fdef.node_def(i).attr()) { @@ -668,7 +662,7 @@ Status InstantiateFunction(const FunctionDef& fdef, } for (int i = 0; i < fdef.node_def_size(); ++i) { - s = helper.BuildNodeOutputIndex(fdef.node_def(i), node_attrs[i], + s = helper.BuildNodeOutputIndex(fdef.node_def(i), AttrSlice(&node_attrs[i]), result->gdef.node_size() + i); if (!s.ok()) { errors::AppendToMessage(&s, "In ", SummarizeNodeDef(fdef.node_def(i))); @@ -677,7 +671,7 @@ Status InstantiateFunction(const FunctionDef& fdef, } // Emits one gdef.node for each fdef.node_def. for (int i = 0; i < fdef.node_def_size(); ++i) { - s = helper.InstantiateNode(fdef.node_def(i), node_attrs[i]); + s = helper.InstantiateNode(fdef.node_def(i), AttrSlice(&node_attrs[i])); if (!s.ok()) { errors::AppendToMessage(&s, "In ", SummarizeNodeDef(fdef.node_def(i))); return s; @@ -748,8 +742,7 @@ bool FunctionDefsEqual(const FunctionDef& f1, const FunctionDef& f2) { return true; } -string Canonicalize(const string& funcname, - const InstantiateAttrValueMap& attrs) { +string Canonicalize(const string& funcname, AttrSlice attrs) { std::vector entries; entries.reserve(attrs.size()); for (auto p : attrs) { @@ -953,8 +946,7 @@ const FunctionDef* FunctionLibraryDefinition::GetAttrImpl( // If ndef is SymbolicGradient[f=Foo], we use Foo's gradient or // Foo's attributes. const NameAttrList* forward_func_attrs; - if (!GetNodeAttr(AttrSlice(&ndef.attr()), kFuncAttr, &forward_func_attrs) - .ok()) { + if (!GetNodeAttr(ndef, kFuncAttr, &forward_func_attrs).ok()) { return nullptr; } const string& func_name = forward_func_attrs->name(); @@ -981,34 +973,30 @@ FunctionDefLibrary FunctionLibraryDefinition::ToProto() const { return lib; } -Status InstantiateFunction(const FunctionDef& fdef, - InstantiateAttrValueSlice attr_values, - GetFunctionSignature get_function, - InstantiationResult* result) { - InstantiateAttrValueMap m; - for (const auto& aval : attr_values) { - m.insert({aval.first, aval.second.proto}); +template +Status FunctionLibraryDefinition::GetAttr(const NodeDef& ndef, + const string& attr, T* value) const { + const FunctionDef* fdef = GetAttrImpl(ndef); + if (fdef && GetNodeAttr(AttrSlice(&fdef->attr()), attr, value).ok()) { + return Status::OK(); } - return InstantiateFunction(fdef, m, std::move(get_function), result); + return errors::InvalidArgument("Attr ", attr, " is not defined."); } -string Canonicalize(const string& funcname, InstantiateAttrValueSlice attrs) { - InstantiateAttrValueMap m; - for (const auto& aval : attrs) { - m.insert({aval.first, aval.second.proto}); - } - return Canonicalize(funcname, m); +template +Status FunctionLibraryDefinition::GetAttr(const Node& node, const string& attr, + T* value) const { + return GetAttr(node.def(), attr, value); } -Status FunctionLibraryRuntime::Instantiate(const string& function_name, - InstantiateAttrValueSlice attrs, - Handle* handle) { - InstantiateAttrValueMap m; - for (const auto& aval : attrs) { - m.insert({aval.first, aval.second.proto}); - } - return Instantiate(function_name, m, handle); -} +#define GET_ATTR(T) \ + template Status FunctionLibraryDefinition::GetAttr(const Node&, \ + const string&, T*) const; \ + template Status FunctionLibraryDefinition::GetAttr(const NodeDef&, \ + const string&, T*) const; +GET_ATTR(string) +GET_ATTR(bool) +#undef GET_ATTR void FunctionDefHelper::AttrValueWrapper::InitFromString(StringPiece val) { if (val.size() >= 2 && val[0] == '$') { diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h index 210e5b949a..188c3855c6 100644 --- a/tensorflow/core/framework/function.h +++ b/tensorflow/core/framework/function.h @@ -36,6 +36,7 @@ class CancellationManager; class OpKernel; class ResourceMgr; class ScopedStepContainer; +class Node; // FunctionDefHelper::Create is a convenient helper to construct a // FunctionDef proto. @@ -190,11 +191,6 @@ inline FunctionDefHelper::AttrValueWrapper::AttrValueWrapper(StringPiece val) { // InstantiateFunction calls "get_function" to find signatures of other // functions and primitive ops. -// Placeholders in "fdef" is substituted based on "attr_values" here. -typedef ::tensorflow::protobuf::Map InstantiateAttrValueMap; -typedef gtl::ArraySlice> - InstantiateAttrValueSlice; - // GetFunctionSignature(func name, opdef) returns OK if the func name is found // and opdef is filled with a pointer to the corresponding signature // (a OpDef proto). Otherwise, returns an error. @@ -206,12 +202,7 @@ struct InstantiationResult { DataTypeVector ret_types; GraphDef gdef; }; -Status InstantiateFunction(const FunctionDef& fdef, - const InstantiateAttrValueMap& attr_values, - GetFunctionSignature get_function, - InstantiationResult* result); -Status InstantiateFunction(const FunctionDef& fdef, - InstantiateAttrValueSlice attr_values, +Status InstantiateFunction(const FunctionDef& fdef, AttrSlice attr_values, GetFunctionSignature get_function, InstantiationResult* result); @@ -241,9 +232,7 @@ bool FunctionDefsEqual(const FunctionDef& f1, const FunctionDef& f2); // space. But it may be change as the implementation // evolves. Therefore, it should not be persisted or compared across // address spaces. -string Canonicalize(const string& funcname, - const InstantiateAttrValueMap& attrs); -string Canonicalize(const string& funcname, InstantiateAttrValueSlice attrs); +string Canonicalize(const string& funcname, AttrSlice attrs); // Represents a function call frame. I.e., the data structure used to // pass arguments to a function and retrieve its results. @@ -330,9 +319,16 @@ class FunctionLibraryDefinition : public OpRegistryInterface { // Given a node def 'ndef', inspects attributes of the callee // function to derive the attribute 'value' for 'attr'. Returns OK // iff the attribute is given by the function's definition. + // TODO(irving): Remove; keep only the const Node& version. template Status GetAttr(const NodeDef& ndef, const string& attr, T* value) const; + // Given a node, inspects attributes of the callee function to derive the + // attribute 'value' for 'attr'. Returns OK iff the attribute is given by the + // function's definition. + template + Status GetAttr(const Node& node, const string& attr, T* value) const; + // Returns a proto representation of the state of this function library. FunctionDefLibrary ToProto() const; @@ -375,11 +371,8 @@ class FunctionLibraryRuntime { // Returns OK and fills in "handle" if the instantiation succeeds. // Otherwise returns an error and "handle" is undefined. typedef uint64 Handle; - virtual Status Instantiate(const string& function_name, - const InstantiateAttrValueMap& attrs, + virtual Status Instantiate(const string& function_name, AttrSlice attrs, Handle* handle) = 0; - Status Instantiate(const string& function_name, - InstantiateAttrValueSlice attrs, Handle* handle); // Returns the function body for the instantiated function given its // handle 'h'. Returns nullptr if "h" is not found. @@ -506,17 +499,15 @@ bool RegisterOp(const string& op, Creator func); Status GetOpGradientCreator(const string& op, Creator* creator); }; -// Implementation details. - -template -Status FunctionLibraryDefinition::GetAttr(const NodeDef& ndef, - const string& attr, T* value) const { - const FunctionDef* fdef = GetAttrImpl(ndef); - if (fdef && GetNodeAttr(AttrSlice(&fdef->attr()), attr, value).ok()) { - return Status::OK(); - } - return errors::InvalidArgument("Attr ", attr, " is not defined."); -} +// Declare explicit instantiations of GetAttr +#define GET_ATTR(T) \ + extern template Status FunctionLibraryDefinition::GetAttr( \ + const Node&, const string&, T*) const; \ + extern template Status FunctionLibraryDefinition::GetAttr( \ + const NodeDef&, const string&, T*) const; +GET_ATTR(string) +GET_ATTR(bool) +#undef GET_ATTR } // end namespace tensorflow diff --git a/tensorflow/core/framework/function_test.cc b/tensorflow/core/framework/function_test.cc index 07462a575e..c83ecf4e5e 100644 --- a/tensorflow/core/framework/function_test.cc +++ b/tensorflow/core/framework/function_test.cc @@ -29,6 +29,24 @@ limitations under the License. #include "tensorflow/core/platform/types.h" namespace tensorflow { +namespace { + +// A helper class to make AttrSlice from initializer lists +class Attrs { + public: + Attrs(const std::initializer_list< // NOLINT(runtime/explicit) + std::pair> + attrs) { + for (const auto& aval : attrs) { + map_.insert({aval.first, aval.second.proto}); + } + } + + operator AttrSlice() { return AttrSlice(&map_); } // NOLINT(runtime/explicit) + + private: + AttrValueMap map_; +}; typedef FunctionDefHelper FDH; @@ -46,8 +64,6 @@ y: A scalar in type T. )doc"); -static InstantiateAttrValueMap kNoAttrs; - TEST(TFunc, SquarePlusOne) { auto fdef = FDH::Create( // Name @@ -81,7 +97,8 @@ SquarePlusOne[T:{float, double, int32, int64}](x:T) -> (y:T) { // Instantiate one with T=float InstantiationResult result; - TF_ASSERT_OK(InstantiateFunction(fdef, {{"T", DT_FLOAT}}, GetOpSig, &result)); + TF_ASSERT_OK( + InstantiateFunction(fdef, Attrs({{"T", DT_FLOAT}}), GetOpSig, &result)); const char* e2 = R"P( (x:float) -> (y:float) { a = Square[T=float](x) @@ -126,7 +143,8 @@ ControlDep(x:int32) -> (y:int32) { // Instantiate one with T=float InstantiationResult result; - TF_ASSERT_OK(InstantiateFunction(fdef, {{"T", DT_FLOAT}}, GetOpSig, &result)); + TF_ASSERT_OK( + InstantiateFunction(fdef, Attrs({{"T", DT_FLOAT}}), GetOpSig, &result)); const char* e2 = R"P( (x:int32) -> (y:int32) { a = Identity[T=int32](x) @@ -171,8 +189,7 @@ BackCompat() -> (y:float) { EXPECT_EQ(DebugString(fdef), e); InstantiationResult result; - TF_ASSERT_OK( - InstantiateFunction(fdef, InstantiateAttrValueMap{}, GetOpSig, &result)); + TF_ASSERT_OK(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result)); // Should get T=float from Op's default. const char* e2 = R"P( () -> (a:float) { @@ -209,7 +226,7 @@ NTimesT(x:float, y:float) -> (z:float) { EXPECT_EQ(DebugString(fdef), e); InstantiationResult result; - TF_ASSERT_OK(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result)); + TF_ASSERT_OK(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result)); const char* e2 = R"P( (x:float, y:float) -> (a:float) { a = AddN[N=2, T=float](x, y) @@ -272,8 +289,8 @@ AddSquared[N:int, T:{float, double, int32, int64}](x:N*T) -> (y:T) { // Instantiate one with T=float InstantiationResult result; - TF_ASSERT_OK(InstantiateFunction(fdef, {{"N", 3}, {"T", DT_FLOAT}}, GetOpSig, - &result)); + TF_ASSERT_OK(InstantiateFunction(fdef, Attrs({{"N", 3}, {"T", DT_FLOAT}}), + GetOpSig, &result)); const char* e2 = R"P( (x_0:float, x_1:float, x_2:float) -> (y:float) { a = Map[N=3, T=float, U=float, func=Square[T=float]](x_0, x_1, x_2) @@ -315,7 +332,7 @@ ControlDeps(x:float) -> () { EXPECT_EQ(DebugString(fdef), e); InstantiationResult result; - TF_ASSERT_OK(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result)); + TF_ASSERT_OK(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result)); const char* e2 = R"P( (x:float) -> () { a = One[T=float]() @ x @@ -395,7 +412,7 @@ Test(i:float) -> (o:float) { EXPECT_EQ(DebugString(fdef), e); InstantiationResult result; - TF_ASSERT_OK(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result)); + TF_ASSERT_OK(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result)); const char* e2 = R"P( (i:float) -> (o:float) { zero = Const[dtype=int32, value=Tensor]() @@ -467,7 +484,7 @@ MySelect(x:float) -> (z:float) { EXPECT_EQ(DebugString(fdef), e); InstantiationResult result; - TF_ASSERT_OK(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result)); + TF_ASSERT_OK(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result)); const char* e2 = R"P( (x:float) -> (z:float) { y = Cond[Tin={float}, cond=MyCond, else_branch=MyElse, out_types={float}, then_branch=MyThen](x) @@ -488,8 +505,9 @@ TEST(InstantiateErrors, Not_Sufficient_Attrs) { auto fdef = FDH::Define("nop", {}, {}, {"T:{float, double, int32, int64}"}, {}); InstantiationResult result; - HasError(InstantiateFunction(fdef, {{"U", DT_FLOAT}}, GetOpSig, &result), - "Attr T is not found from "); + HasError( + InstantiateFunction(fdef, Attrs({{"U", DT_FLOAT}}), GetOpSig, &result), + "Attr T is not found from "); } #if 0 // TODO(josh11b): Enable this test once having an extra attr is an error. @@ -497,7 +515,7 @@ TEST(InstantiateErrors, Too_Many_Attrs) { auto fdef = FDH::Define("nop", {}, {}, {"T:{float, double, int32, int64}"}, {}); InstantiationResult result; - HasError(InstantiateFunction(fdef, {{"T", DT_INT32}, {"U", DT_FLOAT}}, + HasError(InstantiateFunction(fdef, Attrs({{"T", DT_INT32}, {"U", DT_FLOAT}}), GetOpSig, &result), "Attr U is not found in "); } @@ -508,7 +526,7 @@ TEST(InstantiateErrors, AttrValue_Value_Placeholder) { FDH::Define("nop", {}, {}, {"T:{float, double, int32, int64}"}, {}); InstantiationResult result; HasError( - InstantiateFunction(fdef, {{"T", "$bad"}}, GetOpSig, &result), + InstantiateFunction(fdef, Attrs({{"T", "$bad"}}), GetOpSig, &result), "AttrValue had value with unexpected type 'placeholder'\n\tfor attr 'T'"); } @@ -518,14 +536,15 @@ TEST(InstantiateErrors, Unbounded_Attr) { {{"a"}, "One", {}, {{"T", "$unknown"}}, {"x"}}, }); InstantiationResult result; - HasError(InstantiateFunction(fdef, {{"T", DT_FLOAT}}, GetOpSig, &result), - "Failed to bind all placeholders"); + HasError( + InstantiateFunction(fdef, Attrs({{"T", DT_FLOAT}}), GetOpSig, &result), + "Failed to bind all placeholders"); } TEST(InstantiateErrors, DupArgs) { auto fdef = FDH::Define("test", {"x:float", "x:float"}, {}, {}, {}); InstantiationResult result; - HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result), + HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result), "Duplicated arg name"); } @@ -536,7 +555,7 @@ TEST(InstantiateErrors, Dup_Node_Names) { {{"y"}, "One", {}, {{"T", DT_FLOAT}}}, }); InstantiationResult result; - HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result), + HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result), "Duplicated ret name"); } @@ -547,7 +566,7 @@ TEST(InstantiateErrors, Node_Arg_Notfound) { }, {}); InstantiationResult result; - HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result), + HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result), "input z is not found"); } @@ -557,7 +576,7 @@ TEST(InstantiateErrors, Node_Arg_TypeMismatch) { {{"y"}, "Add", {"x", "x"}, {{"T", DT_INT32}}}, }); InstantiationResult result; - HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result), + HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result), "input x[0] expected type int32 != float, the type of x[0]"); } @@ -568,7 +587,7 @@ TEST(InstantiateErrors, Node_Arg_ControlMissing) { {{"y"}, "Add", {"x", "x"}, {{"T", DT_FLOAT}}, {"z"}}, }); InstantiationResult result; - HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result), + HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result), "input[2] == '^z', is not found."); } @@ -579,7 +598,7 @@ TEST(InstantiateErrors, FuncRet_Missing) { }, {}); InstantiationResult result; - HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result), + HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result), "Return y missing"); } @@ -590,7 +609,7 @@ TEST(InstantiateErrors, FuncRet_NotFound) { }, {{"y", "z"}}); InstantiationResult result; - HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result), + HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result), "Return y -> z is not found"); } @@ -601,7 +620,7 @@ TEST(InstantiateErrors, FuncRet_NameMismatch) { }, {{"z", "x:y:0"}}); InstantiationResult result; - HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result), + HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result), "Return y missing"); } @@ -613,7 +632,7 @@ TEST(InstantiateErrors, FuncRet_NameMismatch) { // }, // {{"y", "x:y:0"}, {"z", "x:y:0"}}); // InstantiationResult result; -// HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result), +// HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result), // "ret is not found"); // } @@ -623,7 +642,7 @@ TEST(InstantiateErrors, FuncRet_TypeMismatch) { {{"y"}, "One", {}, {{"T", DT_DOUBLE}}}, }); InstantiationResult result; - HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result), + HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result), "Invalid ret types y : float vs. double\n\tIn function output y"); } @@ -649,7 +668,7 @@ TEST(InstantiateErrors, TypeList_Missing_Retval_Attr) { }, {{"y", "y:output"}}); InstantiationResult result; - HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result), + HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result), "type attr not found: out_types"); } @@ -676,7 +695,7 @@ TEST(InstantiateErrors, TypeList_Num_Retval_Mismatch) { }, {{"y", "y:output"}}); InstantiationResult result; - HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result), + HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result), "Invalid ret types"); } @@ -703,7 +722,7 @@ TEST(InstantiateErrors, TypeList_Missing_Arg) { }, {{"y", "y:output"}}); InstantiationResult result; - HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result), + HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result), "input unknown is not found"); } @@ -724,7 +743,7 @@ TEST(InstantiateErrors, TooManyInputs) { {{"z", "a:sum:0"}}); InstantiationResult result; - HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result), + HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result), "Expected input[2] == 'x' to be a control input."); } @@ -745,7 +764,7 @@ TEST(InstantiateErrors, TooFewInputs) { {{"z", "a:sum:0"}}); InstantiationResult result; - HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result), + HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result), "Attempt to access beyond input size: 2 >= 2"); } @@ -773,7 +792,7 @@ TEST(InstantiateErrors, TooManyInputsFromArray1) { {{"z", "a:sum:0"}}); InstantiationResult result; - HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result), + HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result), "Expected input[1] == 'y' to be a control input."); } @@ -801,7 +820,7 @@ TEST(InstantiateErrors, TooManyInputsFromArray2) { {{"z", "a:sum:0"}}); InstantiationResult result; - HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result), + HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result), "Input a:output too long for inputs"); } @@ -822,7 +841,7 @@ TEST(InstantiateErrors, TypeMismatch) { {{"z", "a:sum:0"}}); InstantiationResult result; - HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result), + HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result), "input inputs[1] expected type float != int32, the type of y[0]"); } @@ -874,17 +893,17 @@ TEST(FunctionCallFrame, Float_Float_Float) { } TEST(Canonicalize, Basic) { - EXPECT_EQ(Canonicalize("MatMul", {{"T", DT_FLOAT}, - {"transpose_a", false}, - {"transpose_b", false}}), + EXPECT_EQ(Canonicalize("MatMul", Attrs({{"T", DT_FLOAT}, + {"transpose_a", false}, + {"transpose_b", false}})), "MatMul[T=float,transpose_a=false,transpose_b=false]"); - EXPECT_EQ(Canonicalize("MatMul", {{"T", DT_FLOAT}, - {"transpose_b", false}, - {"transpose_a", false}}), + EXPECT_EQ(Canonicalize("MatMul", Attrs({{"T", DT_FLOAT}, + {"transpose_b", false}, + {"transpose_a", false}})), "MatMul[T=float,transpose_a=false,transpose_b=false]"); - EXPECT_EQ(Canonicalize("MatMul", {{"T", DT_DOUBLE}, - {"transpose_b", true}, - {"transpose_a", false}}), + EXPECT_EQ(Canonicalize("MatMul", Attrs({{"T", DT_DOUBLE}, + {"transpose_b", true}, + {"transpose_a", false}})), "MatMul[T=double,transpose_a=false,transpose_b=true]"); } @@ -1148,4 +1167,5 @@ TEST(FunctionDefsEqualTest, TestFunctionDefsEqual) { EXPECT_FALSE(FunctionDefsEqual(fdef1, fdef2)); } +} // end namespace } // end namespace tensorflow diff --git a/tensorflow/core/framework/node_def_util.cc b/tensorflow/core/framework/node_def_util.cc index 36c0842924..9b737e1f72 100644 --- a/tensorflow/core/framework/node_def_util.cc +++ b/tensorflow/core/framework/node_def_util.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/core/framework/op_def.pb_text.h" #include "tensorflow/core/framework/op_def_util.h" #include "tensorflow/core/framework/tensor.pb_text.h" +#include "tensorflow/core/graph/graph.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/strings/scanner.h" @@ -36,18 +37,23 @@ namespace tensorflow { const char* const kColocationAttrName = "_class"; const char* const kColocationGroupPrefix = "loc:@"; +AttrSlice::AttrSlice() : ndef_(nullptr) { + static const AttrValueMap* const kEmptyAttrValueMap = new AttrValueMap; + attrs_ = kEmptyAttrValueMap; +} + AttrSlice::AttrSlice(const NodeDef& node_def) : ndef_(&node_def), attrs_(&ndef_->attr()) {} AttrSlice::AttrSlice(const AttrValueMap* a) : ndef_(nullptr), attrs_(a) {} -string SummarizeNodeDef(const NodeDef& node_def) { - string ret = strings::StrCat(node_def.name(), " = ", node_def.op(), "["); +static string SummarizeAttrsHelper(AttrSlice attrs, StringPiece device) { + string ret; // We sort the attrs so the output is deterministic. std::vector attr_names; - attr_names.reserve(node_def.attr().size()); - for (const auto& attr : node_def.attr()) { + attr_names.reserve(attrs.size()); + for (const auto& attr : attrs) { attr_names.push_back(attr.first); } std::sort(attr_names.begin(), attr_names.end()); @@ -55,20 +61,34 @@ string SummarizeNodeDef(const NodeDef& node_def) { for (const string& attr_name : attr_names) { if (!first) strings::StrAppend(&ret, ", "); first = false; - auto iter = node_def.attr().find(attr_name); - strings::StrAppend(&ret, attr_name, "=", SummarizeAttrValue(iter->second)); + strings::StrAppend(&ret, attr_name, "=", + SummarizeAttrValue(*attrs.Find(attr_name))); } // Consider the device to be a final attr with name "_device". - if (!node_def.device().empty()) { + if (!device.empty()) { if (!first) strings::StrAppend(&ret, ", "); first = false; - strings::StrAppend(&ret, "_device=\"", node_def.device(), "\""); + strings::StrAppend(&ret, "_device=\"", device, "\""); } + return ret; +} + +string AttrSlice::SummarizeNode() const { + return ndef_ ? SummarizeNodeDef(*ndef_) + : strings::StrCat( + "[", SummarizeAttrsHelper(*this, StringPiece()), "]"); +} + +string SummarizeNode(const Node& node) { return SummarizeNodeDef(node.def()); } + +string SummarizeNodeDef(const NodeDef& node_def) { + string ret = strings::StrCat(node_def.name(), " = ", node_def.op(), "["); + strings::StrAppend(&ret, SummarizeAttrsHelper(node_def, node_def.device())); strings::StrAppend(&ret, "]("); // Output inputs, including control inputs, verbatim. - first = true; + bool first = true; for (const string& input : node_def.input()) { if (!first) strings::StrAppend(&ret, ", "); first = false; @@ -109,12 +129,28 @@ Status AttrSlice::Find(StringPiece attr_name, // Skip AttachDef for internal attrs since it is a little bit // expensive and it is common for them to correctly not be included // in a NodeDef. - if (!StringPiece(attr_name).starts_with("_") && ndef_) { + if (!attr_name.starts_with("_") && ndef_ != nullptr) { s = AttachDef(s, *ndef_); } return s; } +bool AttrSlice::EqualAttrs(AttrSlice other, Scratch* scratch) const { + if (size() != other.size()) return false; + + for (const auto& attr : *other.attrs_) { + auto iter = attrs_->find(attr.first); + if (iter == attrs_->end()) return false; + // TODO(irving): Comparing AttrValues by proto is slightly buggy, since + // TensorProto is a nonunique representation of Tensor. This bug will go + // away once AttrSlice switches over to NodeInfo. + iter->second.SerializeToString(&scratch->a); + attr.second.SerializeToString(&scratch->b); + if (scratch->a != scratch->b) return false; + } + return true; +} + // The ... is to allow the caller to inject some value validation code. Use // just ; if no additional validation code is needed. #define DEFINE_GET_ATTR(TYPE, FIELD, ATTR_TYPE, APPEND_OP, CAST, ...) \ @@ -341,14 +377,14 @@ Status ValidateNodeDef(const NodeDef& node_def, const OpDef& op_def) { if (StringPiece(input).starts_with("^")) { seen_control = true; if (input.find(':') != string::npos) { - return errors::InvalidArgument("Control input '", input, - "' must not have ':' in NodeDef: ", - SummarizeNodeDef(node_def)); + return errors::InvalidArgument( + "Control input '", input, + "' must not have ':' in NodeDef: ", SummarizeNodeDef(node_def)); } } else if (seen_control) { - return errors::InvalidArgument("Non-control input '", input, - "' after control input in NodeDef: ", - SummarizeNodeDef(node_def)); + return errors::InvalidArgument( + "Non-control input '", input, + "' after control input in NodeDef: ", SummarizeNodeDef(node_def)); } else { ++num_inputs; } @@ -358,8 +394,8 @@ Status ValidateNodeDef(const NodeDef& node_def, const OpDef& op_def) { for (const auto& attr : op_def.attr()) { if (!gtl::InsertIfNotPresent(&op_attrs, attr.name(), &attr)) { return errors::InvalidArgument("OpDef has duplicate attr name '", - attr.name(), "': ", - SummarizeOpDef(op_def)); + attr.name(), + "': ", SummarizeOpDef(op_def)); } } for (const auto& attr : node_def.attr()) { @@ -383,8 +419,9 @@ Status ValidateNodeDef(const NodeDef& node_def, const OpDef& op_def) { "with your GraphDef-generating binary.)."); } TF_RETURN_WITH_CONTEXT_IF_ERROR( - ValidateAttrValue(attr.second, *iter->second), "; NodeDef: ", - SummarizeNodeDef(node_def), "; ", SummarizeOpDef(op_def)); + ValidateAttrValue(attr.second, *iter->second), + "; NodeDef: ", SummarizeNodeDef(node_def), "; ", + SummarizeOpDef(op_def)); // Keep track of which attr names have (not) been found in the NodeDef. op_attrs.erase(iter); } @@ -431,9 +468,9 @@ Status ComputeArgRange(const NodeDef& node_def, const OpDef::ArgDef& arg_def, } else if (!arg_def.type_attr().empty() || arg_def.type() != DT_INVALID) { *num = 1; } else { - return errors::InvalidArgument("Argument '", arg_def.name(), - "' incorrectly specified in op definition: ", - SummarizeOpDef(op_def)); + return errors::InvalidArgument( + "Argument '", arg_def.name(), + "' incorrectly specified in op definition: ", SummarizeOpDef(op_def)); } return Status::OK(); } @@ -465,6 +502,11 @@ Status NameRangesForNode(const NodeDef& node_def, const OpDef& op_def, return Status::OK(); } +Status NameRangesForNode(const Node& node, const OpDef& op_def, + NameRangeMap* inputs, NameRangeMap* outputs) { + return NameRangesForNode(node.def(), op_def, inputs, outputs); +} + void AddDefaultsToNodeDef(const OpDef& op_def, NodeDef* node_def) { for (const auto& attr_def : op_def.attr()) { AttrSlice attrs(*node_def); @@ -565,4 +607,8 @@ Status AttachDef(const Status& status, const NodeDef& node_def) { return ret; } +Status AttachDef(const Status& status, const Node& node) { + return AttachDef(status, node.def()); +} + } // namespace tensorflow diff --git a/tensorflow/core/framework/node_def_util.h b/tensorflow/core/framework/node_def_util.h index 018e4d15f2..1438abdec6 100644 --- a/tensorflow/core/framework/node_def_util.h +++ b/tensorflow/core/framework/node_def_util.h @@ -29,6 +29,8 @@ limitations under the License. namespace tensorflow { +class Node; + // Name of the attribute used to encode node colocation constraints. // // Nodes can be co-located on the same device. Desire for explicit co-location @@ -39,8 +41,9 @@ extern const char* const kColocationAttrName; // String prefix applied to the operation name for colocation constraints. extern const char* const kColocationGroupPrefix; -// Produce a human-readable version of a NodeDef that is more concise +// Produce a human-readable version of a Node or NodeDef that is more concise // than a text-format proto. +string SummarizeNode(const Node& node); string SummarizeNodeDef(const NodeDef& node_def); typedef protobuf::Map AttrValueMap; @@ -78,8 +81,11 @@ class AttrSlice { public: AttrSlice(const NodeDef& node_def); // NOLINT(runtime/explicit) + AttrSlice(); // Empty explicit AttrSlice(const AttrValueMap* a); + int size() const { return attrs_->size(); } + // Returns the attr with attr_name if found. Otherwise, returns // nullptr. const AttrValue* Find(StringPiece attr_name) const; @@ -88,6 +94,33 @@ class AttrSlice { // NotFound status. Status Find(StringPiece attr_name, const AttrValue** attr_value) const; + // Helper class to avoid allocations in EqualAttrs. + // TODO(irving): Will go away once NodeInfo is used. + struct Scratch { + string a; + string b; + }; + + // Check if all attrs and attr values match. Does not take defaults into + // account. + // + // TODO(irving): There is a bug in this routine inherited from its + // OptimizerCSE::EqualAttrs precedecessor. The same tensor attr can be + // represented in more than one way as an AttrValue, since TensorProto is + // not 1-1. This bug will go away once I replace everything with NodeInfo, + // which stores a Tensor object directly. The Scratch object will also go + // away. + bool EqualAttrs(AttrSlice other, Scratch* scratch) const; + + // If this AttrSlice has an attached NodeDef, summarize it. This is for + // error messages only: we intentionally do not provide direct access to the + // NodeDef, since it is not always there. + string SummarizeNode() const; + + // Iteration over all attrs + AttrValueMap::const_iterator begin() const { return attrs_->begin(); } + AttrValueMap::const_iterator end() const { return attrs_->end(); } + private: const NodeDef* ndef_; const AttrValueMap* attrs_; @@ -183,9 +216,12 @@ Status ValidateNodeDef(const NodeDef& node_def, const OpDef& op_def); // corresponding input/output index range. For example, // input "foo" corresponds to input indices // [ (*inputs)["foo"].first, (*inputs)["foo"].second ). +// TODO(irving): Remove the NodeDef version; keep only the Node version. typedef std::unordered_map> NameRangeMap; Status NameRangesForNode(const NodeDef& node_def, const OpDef& op_def, NameRangeMap* inputs, NameRangeMap* outputs); +Status NameRangesForNode(const Node& node, const OpDef& op_def, + NameRangeMap* inputs, NameRangeMap* outputs); // Adds default values to *node_def for unspecified attrs from op_def. void AddDefaultsToNodeDef(const OpDef& op_def, NodeDef* node_def); @@ -206,6 +242,7 @@ Status ValidateExternalNodeDefSyntax(const NodeDef& node_def); // Returns "status" with kernel's NodeDef attached as additional text // in the error message. Status AttachDef(const Status& status, const NodeDef& node_def); +Status AttachDef(const Status& status, const Node& node); } // namespace tensorflow diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc index 422ee80720..6c3917c686 100644 --- a/tensorflow/core/framework/op_kernel.cc +++ b/tensorflow/core/framework/op_kernel.cc @@ -27,6 +27,7 @@ limitations under the License. #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op_def_util.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/graph/graph.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/notification.h" #include "tensorflow/core/lib/core/stringpiece.h" @@ -842,13 +843,10 @@ bool InTypeList(DataType dt, const AttrValue& type_list) { return false; } -// Returns whether the attrs in the NodeDef satisfy the constraints in -// the kernel_def. Returns an error if attrs in kernel_def are not -// found, or have a mismatching type. -Status AttrsMatch(const NodeDef& node_def, const KernelDef& kernel_def, - bool* match) { +// Returns whether the attrs satisfy the constraints in the kernel_def. Returns +// an error if attrs in kernel_def are not found, or have a mismatching type. +Status AttrsMatch(AttrSlice attrs, const KernelDef& kernel_def, bool* match) { *match = false; - AttrSlice attrs(node_def); for (const auto& constraint : kernel_def.constraint()) { if (constraint.allowed_values().list().type_size() == 0) { return errors::Unimplemented( @@ -872,7 +870,7 @@ Status AttrsMatch(const NodeDef& node_def, const KernelDef& kernel_def, "' that has value '", SummarizeAttrValue(*found), "' that does not have type 'type' or 'list(type)' in NodeDef " "'", - SummarizeNodeDef(node_def), "'"); + attrs.SummarizeNode(), "'"); } for (int t : found->list().type()) { @@ -885,7 +883,7 @@ Status AttrsMatch(const NodeDef& node_def, const KernelDef& kernel_def, } else { return errors::InvalidArgument( "OpKernel '", kernel_def.op(), "' has constraint on attr '", - constraint.name(), "' not in NodeDef '", SummarizeNodeDef(node_def), + constraint.name(), "' not in NodeDef '", attrs.SummarizeNode(), "', KernelDef: '", ProtoShortDebugString(kernel_def), "'"); } } @@ -895,6 +893,7 @@ Status AttrsMatch(const NodeDef& node_def, const KernelDef& kernel_def, static const StringPiece kKernelAttr("_kernel"); +// TODO(irving): Replace with const Node& version below. Status FindKernelRegistration(const DeviceType& device_type, const NodeDef& node_def, const KernelRegistration** reg, @@ -927,8 +926,16 @@ Status FindKernelRegistration(const DeviceType& device_type, return Status::OK(); } +Status FindKernelRegistration(const DeviceType& device_type, const Node& node, + const KernelRegistration** reg, + bool* was_attr_mismatch) { + return FindKernelRegistration(device_type, node.def(), reg, + was_attr_mismatch); +} + } // namespace +// TODO(irving): Change const NodeDef& to const Node& Status FindKernelDef(const DeviceType& device_type, const NodeDef& node_def, const KernelDef** def, string* kernel_class_name) { const KernelRegistration* reg = nullptr; diff --git a/tensorflow/core/framework/shape_inference.h b/tensorflow/core/framework/shape_inference.h index cebadcc5b4..d064a8ec4d 100644 --- a/tensorflow/core/framework/shape_inference.h +++ b/tensorflow/core/framework/shape_inference.h @@ -184,8 +184,8 @@ class InferenceContext { } #ifndef NDEBUG for (int i = 0; i < num_outputs(); ++i) { - DCHECK(output(i).IsSet()) << i << " for " << node_def().name() - << " of type " << node_def().op(); + DCHECK(output(i).IsSet()) + << i << " for " << node_def_.name() << " of type " << node_def_.op(); } #endif // NDEBUG return s; @@ -394,11 +394,6 @@ class InferenceContext { // the value. Status MakeDimForScalarInput(int idx, DimensionHandle* out); - // Returns the NodeDef. The returned reference does not outlive the - // InferenceContext, and it should not be used after InferenceContext is - // destroyed. - const NodeDef& node_def() { return node_def_; } - // Look up the attr for the NodeDef being evaluated with name attr_name and // set *value to its value. If no attr with attr_name is found in def(), or // the attr does not have a matching type, a non-ok status will be returned. diff --git a/tensorflow/core/graph/control_flow.cc b/tensorflow/core/graph/control_flow.cc index 8409fb4cd0..db6683d1e7 100644 --- a/tensorflow/core/graph/control_flow.cc +++ b/tensorflow/core/graph/control_flow.cc @@ -88,7 +88,7 @@ Status BuildControlFlowInfo(Graph* g, std::vector* info) { out_info->frame = out; out_info->parent_frame = frame; TF_RETURN_IF_ERROR( - GetNodeAttr(out->def(), "frame_name", &out_info->frame_name)); + GetNodeAttr(out->attrs(), "frame_name", &out_info->frame_name)); if (out_info->frame_name.empty()) { return errors::InvalidArgument("The Enter node ", out->name(), " must have a frame name."); diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc index d765959ca0..9066de5668 100644 --- a/tensorflow/core/graph/graph.cc +++ b/tensorflow/core/graph/graph.cc @@ -78,7 +78,7 @@ string Node::DebugString() const { } else { strings::StrAppend(&ret, " op device:"); strings::StrAppend(&ret, "{", assigned_device_name_, "}"); - strings::StrAppend(&ret, " def:{", SummarizeNodeDef(def()), "}}"); + strings::StrAppend(&ret, " def:{", SummarizeNode(*this), "}}"); } return ret; } @@ -474,7 +474,7 @@ void Graph::ToGraphDefSubRange(GraphDef* graph_def, int from_node_id) const { for (size_t i = 0; i < inputs.size(); ++i) { const Edge* edge = inputs[i]; if (edge == nullptr) { - node_def->add_input(node->def().input(i)); + node_def->add_input(node->requested_inputs()[i]); } else { const Node* src = edge->src(); if (!src->IsOp()) continue; diff --git a/tensorflow/core/graph/graph.h b/tensorflow/core/graph/graph.h index ac22dfc324..8554cb2f4b 100644 --- a/tensorflow/core/graph/graph.h +++ b/tensorflow/core/graph/graph.h @@ -71,6 +71,7 @@ class Node { int cost_id() const { return cost_id_; } const string& name() const { return props_->node_def_.name(); } const string& type_string() const { return props_->node_def_.op(); } + // def() provides the NodeDef the user supplied, but the specifics // of this Node may have changed due to placement, optimization, etc. // In particular: @@ -80,6 +81,7 @@ class Node { // * def().device() is the "user's requested device" and may not match // the actual assigned device, see assigned_device_name() below; // * def().attr() is authoritative. + // TODO(irving): Replace with NodeInfo. const NodeDef& def() const { return props_->node_def_; } const OpDef& op_def() const { return *props_->op_def_; } @@ -92,6 +94,10 @@ class Node { DataType output_type(int32 o) const { return props_->output_types_[o]; } const DataTypeVector& output_types() const { return props_->output_types_; } + // The device requested by the user. For the actual assigned device, + // use assigned_device_name() below. + const string& requested_device() const { return def().device(); } + // This gives the device the runtime has assigned this node to. If // you want the device the user requested, use def().device() instead. // TODO(josh11b): Validate that the assigned_device, if not empty: @@ -103,6 +109,14 @@ class Node { assigned_device_name_ = device_name; } + // Read only access to attributes + AttrSlice attrs() const { return AttrSlice(def()); } + + // Inputs requested by the NodeDef. For the actual inputs, use in_edges. + const protobuf::RepeatedPtrField& requested_inputs() const { + return def().input(); + } + // Get the neighboring nodes via edges either in or out of this node. gtl::iterator_range in_nodes() const; gtl::iterator_range out_nodes() const; diff --git a/tensorflow/core/graph/graph_constructor.cc b/tensorflow/core/graph/graph_constructor.cc index 9d4a0a52f7..70087b8fe1 100644 --- a/tensorflow/core/graph/graph_constructor.cc +++ b/tensorflow/core/graph/graph_constructor.cc @@ -424,7 +424,7 @@ Status GraphConstructor::ValidateShape(Node* node) { // For nodes with the _output_shapes atttribute, override the shape. std::vector shape_attrs; const char* kAttrName = "_output_shapes"; - if (!GetNodeAttr(node->def(), kAttrName, &shape_attrs).ok()) { + if (!GetNodeAttr(node->attrs(), kAttrName, &shape_attrs).ok()) { // No _output_shapes attribute, the AddNode call above was sufficient. return Status::OK(); } @@ -458,7 +458,7 @@ Status GraphConstructor::ValidateShape(Node* node) { // functions that are not critical to correct execution but // would cause graphs to fail if imported after correcting. // - const string& op = node->def().op(); + const string& op = node->type_string(); const std::vector whitelist = { // To be removed after 2017/03/08. "RandomShuffleQueue", "PaddingFIFOQueue", "FIFOQueue", diff --git a/tensorflow/core/graph/graph_constructor_test.cc b/tensorflow/core/graph/graph_constructor_test.cc index e3b7f322cb..6013b2ff51 100644 --- a/tensorflow/core/graph/graph_constructor_test.cc +++ b/tensorflow/core/graph/graph_constructor_test.cc @@ -146,7 +146,7 @@ class GraphConstructorTest : public ::testing::Test { return ""; } std::vector value; - Status s = GetNodeAttr(n->def(), kColocationAttrName, &value); + Status s = GetNodeAttr(n->attrs(), kColocationAttrName, &value); if (!s.ok()) { return ""; } @@ -997,7 +997,7 @@ TEST_F(GraphConstructorTest, ImportGraphDef_DefaultAttrs) { } ASSERT_TRUE(a != nullptr); int value = 0; - s = GetNodeAttr(a->def(), "default_int", &value); + s = GetNodeAttr(a->attrs(), "default_int", &value); ASSERT_EQ(Status::OK(), s) << s << " -- " << a->def().DebugString(); EXPECT_EQ(31415, value); } @@ -1201,9 +1201,9 @@ TEST_F(GraphConstructorTest, ImportGraphDef_InputMap) { // Check that t1's NodeDef is consistent with graph Node* t1 = FindNode("t1"); - ASSERT_EQ(t1->def().input_size(), 2); - ASSERT_EQ(t1->def().input(0), "input:1"); - ASSERT_EQ(t1->def().input(1), "input:0"); + ASSERT_EQ(t1->requested_inputs().size(), 2); + ASSERT_EQ(t1->requested_inputs()[0], "input:1"); + ASSERT_EQ(t1->requested_inputs()[1], "input:0"); } TEST_F(GraphConstructorTest, ImportGraphDef_InputMapWithPrefix) { @@ -1254,19 +1254,19 @@ TEST_F(GraphConstructorTest, ImportGraphDef_InputMapWithPrefix) { // Check that NodeDefs are consistent with graph Node* t1 = FindNode("import/t1"); - ASSERT_EQ(t1->def().input_size(), 2); - EXPECT_EQ(t1->def().input(0), "input:0"); - EXPECT_EQ(t1->def().input(1), "input:0"); + ASSERT_EQ(t1->requested_inputs().size(), 2); + EXPECT_EQ(t1->requested_inputs()[0], "input:0"); + EXPECT_EQ(t1->requested_inputs()[1], "input:0"); Node* t2 = FindNode("import/t2"); - ASSERT_EQ(t2->def().input_size(), 2); - EXPECT_EQ(t2->def().input(0), "import/t1:0"); - EXPECT_EQ(t2->def().input(1), "import/t1:0"); + ASSERT_EQ(t2->requested_inputs().size(), 2); + EXPECT_EQ(t2->requested_inputs()[0], "import/t1:0"); + EXPECT_EQ(t2->requested_inputs()[1], "import/t1:0"); Node* t3 = FindNode("import/t3"); - ASSERT_EQ(t3->def().input_size(), 2); - EXPECT_EQ(t3->def().input(0), "import/unmapped_input:0"); - EXPECT_EQ(t3->def().input(1), "import/unmapped_input:1"); + ASSERT_EQ(t3->requested_inputs().size(), 2); + EXPECT_EQ(t3->requested_inputs()[0], "import/unmapped_input:0"); + EXPECT_EQ(t3->requested_inputs()[1], "import/unmapped_input:1"); } TEST_F(GraphConstructorTest, ImportGraphDef_InputMapWithControlEdges) { @@ -1795,24 +1795,24 @@ TEST_F(GraphConstructorTest, ImportGraphDef_ControlDeps) { // Test that node defs are consistent with graph Node* w1 = FindNode("import/W1"); - ASSERT_EQ(w1->def().input_size(), 2); - EXPECT_EQ(w1->def().input(0), "^W1"); - EXPECT_EQ(w1->def().input(1), "^W2"); + ASSERT_EQ(w1->requested_inputs().size(), 2); + EXPECT_EQ(w1->requested_inputs()[0], "^W1"); + EXPECT_EQ(w1->requested_inputs()[1], "^W2"); Node* input = FindNode("import/input"); - ASSERT_EQ(input->def().input_size(), 2); - EXPECT_EQ(input->def().input(0), "^W1"); - EXPECT_EQ(input->def().input(1), "^W2"); + ASSERT_EQ(input->requested_inputs().size(), 2); + EXPECT_EQ(input->requested_inputs()[0], "^W1"); + EXPECT_EQ(input->requested_inputs()[1], "^W2"); Node* input2 = FindNode("import/input2"); - ASSERT_EQ(input2->def().input_size(), 2); - EXPECT_EQ(input2->def().input(0), "^W1"); - EXPECT_EQ(input2->def().input(1), "^W2"); + ASSERT_EQ(input2->requested_inputs().size(), 2); + EXPECT_EQ(input2->requested_inputs()[0], "^W1"); + EXPECT_EQ(input2->requested_inputs()[1], "^W2"); Node* t1 = FindNode("import/t1"); - ASSERT_EQ(t1->def().input_size(), 2); - EXPECT_EQ(t1->def().input(0), "import/input:0"); - EXPECT_EQ(t1->def().input(1), "import/input:1"); + ASSERT_EQ(t1->requested_inputs().size(), 2); + EXPECT_EQ(t1->requested_inputs()[0], "import/input:0"); + EXPECT_EQ(t1->requested_inputs()[1], "import/input:1"); } TEST_F(GraphConstructorTest, ImportGraphDef_ControlDepsWithCycle) { @@ -1856,15 +1856,15 @@ TEST_F(GraphConstructorTest, ImportGraphDef_ControlDepsWithCycle) { // Test that node defs are consistent with graph Node* merge = FindNode("merge"); - ASSERT_EQ(merge->def().input_size(), 3); - EXPECT_EQ(merge->def().input(0), "input:0"); - EXPECT_EQ(merge->def().input(1), "t1:0"); - EXPECT_EQ(merge->def().input(2), "^W1"); + ASSERT_EQ(merge->requested_inputs().size(), 3); + EXPECT_EQ(merge->requested_inputs()[0], "input:0"); + EXPECT_EQ(merge->requested_inputs()[1], "t1:0"); + EXPECT_EQ(merge->requested_inputs()[2], "^W1"); Node* t1 = FindNode("t1"); - ASSERT_EQ(t1->def().input_size(), 2); - EXPECT_EQ(t1->def().input(0), "merge:0"); - EXPECT_EQ(t1->def().input(1), "merge:0"); + ASSERT_EQ(t1->requested_inputs().size(), 2); + EXPECT_EQ(t1->requested_inputs()[0], "merge:0"); + EXPECT_EQ(t1->requested_inputs()[1], "merge:0"); } TEST_F(GraphConstructorTest, ImportGraphDef_ControlDepsErrors) { diff --git a/tensorflow/core/graph/graph_partition.cc b/tensorflow/core/graph/graph_partition.cc index c7ad6a1e77..57a2f399e0 100644 --- a/tensorflow/core/graph/graph_partition.cc +++ b/tensorflow/core/graph/graph_partition.cc @@ -356,7 +356,7 @@ string ControlLoopName(const string& name) { } bool IsControlLoop(const Node* node) { - const string& name = node->def().name(); + const string& name = node->name(); return StringPiece(name).starts_with("_cloop"); } @@ -468,7 +468,7 @@ Status AddControlLoop(const PartitionOptions& opts, Graph* g, const Node* src, const string& device_name = edge->dst()->assigned_device_name(); const string& frame_name = src_info.frame_name; int parallel_iterations; - status = GetNodeAttr(src_info.frame->def(), "parallel_iterations", + status = GetNodeAttr(src_info.frame->attrs(), "parallel_iterations", ¶llel_iterations); if (!status.ok()) return status; @@ -903,11 +903,11 @@ Status Partition(const PartitionOptions& opts, Graph* g, send_start_time = opts.start_times[src->id()].value(); recv_start_time = opts.start_times[dst->id()].value(); } else { - status = GetNodeAttr(src->def(), "_start_time", &send_start_time); + status = GetNodeAttr(src->attrs(), "_start_time", &send_start_time); if (!status.ok()) { return status; } - status = GetNodeAttr(dst->def(), "_start_time", &recv_start_time); + status = GetNodeAttr(dst->attrs(), "_start_time", &recv_start_time); if (!status.ok()) { return status; } diff --git a/tensorflow/core/graph/graph_test.cc b/tensorflow/core/graph/graph_test.cc index 4afc878f76..89784c631f 100644 --- a/tensorflow/core/graph/graph_test.cc +++ b/tensorflow/core/graph/graph_test.cc @@ -318,21 +318,21 @@ TEST_F(GraphTest, AddAttr) { n1->AddAttr("_a", "new_attr"); string attr; - EXPECT_EQ(Status::OK(), GetNodeAttr(n1->def(), "_a", &attr)); + EXPECT_EQ(Status::OK(), GetNodeAttr(n1->attrs(), "_a", &attr)); EXPECT_EQ("new_attr", attr); Node* n2 = graph_.CopyNode(n1); n1->AddAttr("_b", "new_attr_2"); - EXPECT_EQ(Status::OK(), GetNodeAttr(n1->def(), "_a", &attr)); + EXPECT_EQ(Status::OK(), GetNodeAttr(n1->attrs(), "_a", &attr)); EXPECT_EQ("new_attr", attr); - EXPECT_EQ(Status::OK(), GetNodeAttr(n1->def(), "_b", &attr)); + EXPECT_EQ(Status::OK(), GetNodeAttr(n1->attrs(), "_b", &attr)); EXPECT_EQ("new_attr_2", attr); - EXPECT_EQ(Status::OK(), GetNodeAttr(n2->def(), "_a", &attr)); + EXPECT_EQ(Status::OK(), GetNodeAttr(n2->attrs(), "_a", &attr)); EXPECT_EQ("new_attr", attr); - EXPECT_NE(Status::OK(), GetNodeAttr(n2->def(), "_b", &attr)); + EXPECT_NE(Status::OK(), GetNodeAttr(n2->attrs(), "_b", &attr)); } // Convert edge iteration results into a sorted string. diff --git a/tensorflow/core/graph/optimizer_cse.cc b/tensorflow/core/graph/optimizer_cse.cc index a679eac0e7..a22a9b3fa3 100644 --- a/tensorflow/core/graph/optimizer_cse.cc +++ b/tensorflow/core/graph/optimizer_cse.cc @@ -56,11 +56,9 @@ class OptimizerCSE { bool Optimize(const std::function& consider_fn); private: - struct Scratch; - static size_t NodeHash(const Node* n); - static bool Equivalent(const Node* a, const Node* b, Scratch* s); - static bool EqualAttrs(const Node* a, const Node* b, Scratch* s); + static bool Equivalent(const Node* a, const Node* b, + AttrSlice::Scratch* scratch); Graph* g_; }; @@ -110,7 +108,7 @@ size_t OptimizerCSE::NodeHash(const Node* n) { // Hash the attrs. For example, this makes sure different constants // end up in different hash buckets. string tmp; - for (const auto& attr : n->def().attr()) { + for (const auto& attr : n->attrs()) { tmp = attr.first; attr.second.AppendToString(&tmp); // Add hashes of attrs, so the order of attrs doesn't matter. @@ -122,28 +120,6 @@ size_t OptimizerCSE::NodeHash(const Node* n) { return h; } -struct OptimizerCSE::Scratch { - // For EqualAttrs(): - string a; - string b; -}; - -bool OptimizerCSE::EqualAttrs(const Node* a, const Node* b, Scratch* scratch) { - if (a->def().attr_size() != b->def().attr_size()) return false; - - for (const auto& attr : b->def().attr()) { - auto iter = a->def().attr().find(attr.first); - if (iter == a->def().attr().end()) return false; - // Note: it should be safe to compare proto serializations of the attr - // values since at most one field should be set in each (indeed, it - // should be the same field). - iter->second.SerializeToString(&scratch->a); - attr.second.SerializeToString(&scratch->b); - if (scratch->a != scratch->b) return false; - } - return true; -} - static bool HasRefInput(const Node* n) { for (auto dt : n->input_types()) { if (IsRefType(dt)) return true; @@ -151,7 +127,8 @@ static bool HasRefInput(const Node* n) { return false; } -bool OptimizerCSE::Equivalent(const Node* a, const Node* b, Scratch* scratch) { +bool OptimizerCSE::Equivalent(const Node* a, const Node* b, + AttrSlice::Scratch* scratch) { // Different op names are different if (a->type_string() != b->type_string()) return false; @@ -164,7 +141,7 @@ bool OptimizerCSE::Equivalent(const Node* a, const Node* b, Scratch* scratch) { // Compare attrs. Note that equal attrs implies equal input and // output types. - if (!EqualAttrs(a, b, scratch)) return false; + if (!a->attrs().EqualAttrs(b->attrs(), scratch)) return false; // Compare input sources if (a->num_inputs() != b->num_inputs()) return false; @@ -206,7 +183,7 @@ bool OptimizerCSE::Optimize( // Scratch space for Equivalent calls. Allocated here and passed in to // Equivalent to avoid allocation inside the loop below. bool changed = false; - Scratch scratch; + AttrSlice::Scratch scratch; for (Node* n : order) { if (!n->IsOp()) continue; diff --git a/tensorflow/core/graph/quantize_training.cc b/tensorflow/core/graph/quantize_training.cc index e3ef5e2f0c..4a479d3258 100644 --- a/tensorflow/core/graph/quantize_training.cc +++ b/tensorflow/core/graph/quantize_training.cc @@ -192,9 +192,9 @@ Status ConnectVariablesToSaveOp(Graph* graph, Node* save_op, Tensor tensor_names; Tensor shape_and_slices; TF_RETURN_IF_ERROR( - GetNodeAttr(AttrSlice(tensor_names_op->def()), "value", &tensor_names)); - TF_RETURN_IF_ERROR(GetNodeAttr(AttrSlice(shape_and_slices_op->def()), "value", - &shape_and_slices)); + GetNodeAttr(tensor_names_op->attrs(), "value", &tensor_names)); + TF_RETURN_IF_ERROR( + GetNodeAttr(shape_and_slices_op->attrs(), "value", &shape_and_slices)); int tn_size = tensor_names.NumElements(); int var_size = added_variables.size(); diff --git a/tensorflow/core/graph/quantize_training_test.cc b/tensorflow/core/graph/quantize_training_test.cc index 9cbb928c11..d817d980de 100644 --- a/tensorflow/core/graph/quantize_training_test.cc +++ b/tensorflow/core/graph/quantize_training_test.cc @@ -112,17 +112,15 @@ TEST_F(QuantizeTrainingTest, SignedInput) { TF_ASSERT_OK( FindNode(g, strings::StrCat(identity->name(), "/QuantizeAndDequantizeV2"), &identity_q_node)); - NodeDef identity_q = identity_q_node->def(); ASSERT_EQ("true", - SummarizeAttrValue(identity_q.attr().find("signed_input")->second)); + SummarizeAttrValue(*identity_q_node->attrs().Find("signed_input"))); // Quantize_and_dequantize node for relu should have signed_input==false. Node* relu_q_node; TF_ASSERT_OK( FindNode(g, strings::StrCat(relu->name(), "/QuantizeAndDequantizeV2"), &relu_q_node)); - NodeDef relu_q = relu_q_node->def(); ASSERT_EQ("false", - SummarizeAttrValue(relu_q.attr().find("signed_input")->second)); + SummarizeAttrValue(*relu_q_node->attrs().Find("signed_input"))); } TEST_F(QuantizeTrainingTest, RangeGivenTrue) { @@ -165,17 +163,15 @@ TEST_F(QuantizeTrainingTest, RangeGivenTrue) { TF_ASSERT_OK( FindNode(g, strings::StrCat(relu6->name(), "/QuantizeAndDequantizeV2"), &relu6_q_node)); - NodeDef identity_q = relu6_q_node->def(); ASSERT_EQ("true", - SummarizeAttrValue(identity_q.attr().find("range_given")->second)); + SummarizeAttrValue(*relu6_q_node->attrs().Find("range_given"))); // Quantize_and_dequantize node for relu should have range_given==true. Node* relu_q_node; TF_ASSERT_OK( FindNode(g, strings::StrCat(relu->name(), "/QuantizeAndDequantizeV2"), &relu_q_node)); - NodeDef relu_q = relu_q_node->def(); ASSERT_EQ("true", - SummarizeAttrValue(relu_q.attr().find("range_given")->second)); + SummarizeAttrValue(*relu_q_node->attrs().Find("range_given"))); } TEST_F(QuantizeTrainingTest, WithBackwardNodes_QuantizeAndDequantize) { diff --git a/tensorflow/core/graph/subgraph.cc b/tensorflow/core/graph/subgraph.cc index 9849d9a159..e10b692889 100644 --- a/tensorflow/core/graph/subgraph.cc +++ b/tensorflow/core/graph/subgraph.cc @@ -106,7 +106,7 @@ static Status FeedInputs(Graph* g, const DeviceAttributes& device_info, // Copy the _output_shapes from the original node to the feed node, // if any. std::vector output_shapes; - if (GetNodeAttr(n->def(), "_output_shapes", &output_shapes).ok()) { + if (GetNodeAttr(n->attrs(), "_output_shapes", &output_shapes).ok()) { if (n->num_outputs() != output_shapes.size()) { return errors::InvalidArgument( "FeedInputs: ", t, @@ -129,8 +129,8 @@ static Status FeedInputs(Graph* g, const DeviceAttributes& device_info, if (e->src_output() == id.second) { to_remove.emplace_back(e); } else if (e->src_output() == Graph::kControlSlot && - (n->def().op() == "Placeholder" || - n->def().op() == "PlaceholderV2")) { + (n->type_string() == "Placeholder" || + n->type_string() == "PlaceholderV2")) { // When feeding a Placeholder node, any outgoing control edges // will be replaced with a control edge from the replacement // recv_node. diff --git a/tensorflow/core/graph/subgraph_test.cc b/tensorflow/core/graph/subgraph_test.cc index 3dc11b7a16..93dcfd5e33 100644 --- a/tensorflow/core/graph/subgraph_test.cc +++ b/tensorflow/core/graph/subgraph_test.cc @@ -81,7 +81,7 @@ class SubgraphTest : public ::testing::Test { for (const string& s : expected_nodes) { Node* n = FindNode(s); EXPECT_TRUE(n != nullptr) << s; - if (n->def().op() == "_Send" || n->def().op() == "_Recv") { + if (n->type_string() == "_Send" || n->type_string() == "_Recv") { EXPECT_EQ(device_info_.name(), n->assigned_device_name()) << s; } } @@ -367,7 +367,7 @@ TEST_F(SubgraphTest, FedOutputsPreservesOutputShapes) { for (Node* node : graph()->nodes()) { if (node->name() == "_recv_input_1") { std::vector shapes; - TF_ASSERT_OK(GetNodeAttr(node->def(), "_output_shapes", &shapes)); + TF_ASSERT_OK(GetNodeAttr(node->attrs(), "_output_shapes", &shapes)); ASSERT_EQ(1, shapes.size()); EXPECT_TRUE(PartialTensorShape({23}).IsIdenticalTo(shapes[0])); break; diff --git a/tensorflow/core/grappler/costs/graph_properties.cc b/tensorflow/core/grappler/costs/graph_properties.cc index 035483ec17..b0e69d44ed 100644 --- a/tensorflow/core/grappler/costs/graph_properties.cc +++ b/tensorflow/core/grappler/costs/graph_properties.cc @@ -151,8 +151,8 @@ Status GraphProperties::InferStatically() { if (!node->assigned_device_name().empty()) { device_names_[node->name()] = node->assigned_device_name(); - } else if (!node->def().device().empty()) { - device_names_[node->name()] = node->def().device(); + } else if (!node->requested_device().empty()) { + device_names_[node->name()] = node->requested_device(); } else { device_names_[node->name()] = "not set"; } diff --git a/tensorflow/core/kernels/function_ops.cc b/tensorflow/core/kernels/function_ops.cc index ba408f3657..8c3137ece9 100644 --- a/tensorflow/core/kernels/function_ops.cc +++ b/tensorflow/core/kernels/function_ops.cc @@ -186,12 +186,12 @@ REGISTER_KERNEL_BUILDER(Name("_ArrayToList") PassOn); #ifdef TENSORFLOW_USE_SYCL -#define REGISTER_SYCL_KERNELS(type) \ - REGISTER_KERNEL_BUILDER( \ - Name("_ListToArray").Device(DEVICE_SYCL).TypeConstraint("T"),\ - PassOn); \ - REGISTER_KERNEL_BUILDER( \ - Name("_ArrayToList").Device(DEVICE_SYCL).TypeConstraint("T"),\ +#define REGISTER_SYCL_KERNELS(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("_ListToArray").Device(DEVICE_SYCL).TypeConstraint("T"), \ + PassOn); \ + REGISTER_KERNEL_BUILDER( \ + Name("_ArrayToList").Device(DEVICE_SYCL).TypeConstraint("T"), \ PassOn); REGISTER_SYCL_KERNELS(float); @@ -211,7 +211,7 @@ REGISTER_KERNEL_BUILDER(Name("_ArrayToList") .HostMemory("output") .TypeConstraint("T"), PassOn); -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL class SymbolicGradientOp : public AsyncOpKernel { public: @@ -227,7 +227,7 @@ class SymbolicGradientOp : public AsyncOpKernel { FunctionLibraryRuntime::Handle handle; OP_REQUIRES_OK_ASYNC( - ctx, lib->Instantiate(kGradientOp, def().attr(), &handle), done); + ctx, lib->Instantiate(kGradientOp, AttrSlice(def()), &handle), done); FunctionLibraryRuntime::Options opts; opts.step_id = ctx->step_id(); diff --git a/tensorflow/core/kernels/hexagon/graph_transferer.cc b/tensorflow/core/kernels/hexagon/graph_transferer.cc index 6f1616eff3..d927ef3efa 100644 --- a/tensorflow/core/kernels/hexagon/graph_transferer.cc +++ b/tensorflow/core/kernels/hexagon/graph_transferer.cc @@ -140,7 +140,7 @@ Status GraphTransferer::LoadGraphFromProto( std::vector data_types; std::vector shapes; status = RemoteFusedGraphExecuteUtils::GetOutputTensorShapeType( - node->def(), &data_types, &shapes); + node->attrs(), &data_types, &shapes); if (status.ok()) { CHECK(data_types.size() > port); graph_output_node_info.set_dtype(data_types.at(port)); @@ -359,7 +359,7 @@ void GraphTransferer::RegisterConstantNode(const ShapeRefiner& shape_refiner, const_node_info.add_shape(shape_array[2]); const_node_info.add_shape(shape_array[3]); const TensorProto* proto = nullptr; - TF_CHECK_OK(GetNodeAttr(node.def(), "value", &proto)); + TF_CHECK_OK(GetNodeAttr(node.attrs(), "value", &proto)); Tensor const_tensor; // TODO(b/32704451): Don't just ignore this status! MakeTensorFromProto(*proto, &const_tensor).IgnoreError(); @@ -395,8 +395,9 @@ int GraphTransferer::RegisterConstantShape(const std::vector& shape) { } bool GraphTransferer::HasPaddingAndStrides(const Node& node) { - return node.def().attr().count(PADDING_ATTR_NAME) > 0 && - node.def().attr().count(STRIDES_ATTR_NAME) > 0; + auto attrs = node.attrs(); + return attrs.Find(PADDING_ATTR_NAME) != nullptr && + attrs.Find(STRIDES_ATTR_NAME) != nullptr; } bool GraphTransferer::IsNodeFlattenReshape(const Node& node, @@ -423,7 +424,7 @@ bool GraphTransferer::IsNodeFlattenReshape(const Node& node, } else { std::vector shapes; TF_CHECK_OK(RemoteFusedGraphExecuteUtils::GetOutputTensorShapeType( - node.def(), nullptr, &shapes)); + node.attrs(), nullptr, &shapes)); // Number of outputs should be 1 for reshape node. CHECK_EQ(1, shapes.size()); @@ -444,16 +445,16 @@ void GraphTransferer::RegisterNodeWithPaddingAndStrides( CHECK_EQ(node_name_to_id_cache_map_.count(node.name()), 1); const int id = node_name_to_id_cache_map_[node.name()]; shape_inference::InferenceContext* context = shape_refiner.GetContext(&node); - CHECK_GT(node.def().attr().count(PADDING_ATTR_NAME), 0); + CHECK(node.attrs().Find(PADDING_ATTR_NAME)); // TODO(satok): Use context->GetAttr(...) instead? Padding padding; TF_CHECK_OK(context->GetAttr(PADDING_ATTR_NAME, &padding)); - CHECK_GT(node.def().attr().count(STRIDES_ATTR_NAME), 0); + CHECK(node.attrs().Find(STRIDES_ATTR_NAME)); std::vector strides; TF_CHECK_OK(context->GetAttr(STRIDES_ATTR_NAME, &strides)); const int stride_id = RegisterConstantShape(strides); std::vector extra_inputs{stride_id}; - if (node.def().attr().count(KSIZE_ATTR_NAME) > 0) { + if (node.attrs().Find(KSIZE_ATTR_NAME)) { std::vector kernel_sizes; TF_CHECK_OK(context->GetAttr(KSIZE_ATTR_NAME, &kernel_sizes)); const int ksize_id = RegisterConstantShape(kernel_sizes); @@ -597,7 +598,7 @@ void GraphTransferer::AppendNodeOutputParams(const ShapeRefiner& shape_refiner, std::vector shapes; Status status = RemoteFusedGraphExecuteUtils::GetOutputTensorShapeType( - node.def(), nullptr, &shapes); + node.attrs(), nullptr, &shapes); for (int i = 0; i < node.num_outputs(); ++i) { int data_size = -1; diff --git a/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc b/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc index 456611894d..2174098bde 100644 --- a/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc +++ b/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc @@ -333,17 +333,17 @@ RemoteFusedGraphExecuteUtils::AddOutputTensorShapeTypeByTensorShapeMap( } /* static */ Status RemoteFusedGraphExecuteUtils::GetOutputTensorShapeType( - const NodeDef& node_def, std::vector* data_types, + AttrSlice attrs, std::vector* data_types, std::vector* shapes) { Status status; if (data_types != nullptr) { - status = GetNodeAttr(node_def, ATTR_OUTPUT_DATA_TYPES, data_types); + status = GetNodeAttr(attrs, ATTR_OUTPUT_DATA_TYPES, data_types); } if (!status.ok()) { return status; } if (shapes != nullptr) { - status = GetNodeAttr(node_def, ATTR_OUTPUT_SHAPES, shapes); + status = GetNodeAttr(attrs, ATTR_OUTPUT_SHAPES, shapes); if (status.ok() && data_types != nullptr) { CHECK_EQ(data_types->size(), shapes->size()); } diff --git a/tensorflow/core/kernels/remote_fused_graph_execute_utils.h b/tensorflow/core/kernels/remote_fused_graph_execute_utils.h index 9ec189c85f..97b0c2008a 100644 --- a/tensorflow/core/kernels/remote_fused_graph_execute_utils.h +++ b/tensorflow/core/kernels/remote_fused_graph_execute_utils.h @@ -103,7 +103,7 @@ class RemoteFusedGraphExecuteUtils { static Status AddOutputTensorShapeTypeByTensorShapeMap( const TensorShapeMap& tensor_shape_map, NodeDef* node_def); - static Status GetOutputTensorShapeType(const NodeDef& node_def, + static Status GetOutputTensorShapeType(AttrSlice attrs, std::vector* data_types, std::vector* shapes); -- GitLab From 7161c8205794f744653ffe6739a0c4038e1d6cb8 Mon Sep 17 00:00:00 2001 From: Beomsu Kim <123bskim@naver.com> Date: Wed, 17 May 2017 13:35:37 +0900 Subject: [PATCH 696/697] Fix contrib.seq2seq.BeamSearchDecoder #9855 (#9875) * Shielded top_k input with min Fixes #9855. * Fix docstring note about BeamSearch _get_scores * Change next_beam_size to tensor * Calculate number of available beams before cond --- .../seq2seq/python/ops/beam_search_decoder.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py index eb494bda4b..c9be517fad 100644 --- a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py +++ b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py @@ -497,13 +497,22 @@ def _beam_search_step(time, logits, beam_state, batch_size, beam_width, time = ops.convert_to_tensor(time, name="time") # During the first time step we only consider the initial beam + scores_shape = array_ops.shape(scores) scores_flat = control_flow_ops.cond( time > 0, lambda: array_ops.reshape(scores, [batch_size, -1]), lambda: scores[:, 0]) + num_available_beam = control_flow_ops.cond( + time > 0, + lambda: math_ops.reduce_prod(scores_shape[1:]), + lambda: math_ops.reduce_prod(scores_shape[2:])) # Pick the next beams according to the specified successors function - next_beam_scores, word_indices = nn_ops.top_k(scores_flat, k=beam_width) + next_beam_size = math_ops.minimum( + ops.convert_to_tensor( + beam_width, dtype=dtypes.int32, name="beam_width"), + num_available_beam) + next_beam_scores, word_indices = nn_ops.top_k(scores_flat, k=next_beam_size) next_beam_scores.set_shape([static_batch_size, beam_width]) word_indices.set_shape([static_batch_size, beam_width]) @@ -561,7 +570,8 @@ def _get_scores(log_probs, sequence_lengths, length_penalty_weight): """Calculates scores for beam search hypotheses. Args: - log_probs: The log probabilities with shape [batch_size, beam_width]. + log_probs: The log probabilities with shape + `[batch_size, beam_width, vocab_size]`. sequence_lengths: The array of sequence lengths. length_penalty_weight: Float weight to penalize length. Disabled with 0.0. -- GitLab From 513b1e4bab3eee797877a1250fa18e3ae9c349ad Mon Sep 17 00:00:00 2001 From: Corey Wharton Date: Tue, 16 May 2017 21:36:05 -0700 Subject: [PATCH 697/697] Allow tensor as iou_threshold parameter to tf.image.non_max_suppression. (#9887) * Implement NonMaxSuppressionV2 op that allows tensor as iou_threshold parameter. * Move local functions into anonymous namespace. --- .../core/kernels/non_max_suppression_op.cc | 142 +++++++++------ .../kernels/non_max_suppression_op_test.cc | 165 +++++++++++++++++- tensorflow/core/ops/image_ops.cc | 47 ++++- tensorflow/python/ops/image_ops.py | 1 + tensorflow/python/ops/image_ops_impl.py | 1 + 5 files changed, 299 insertions(+), 57 deletions(-) diff --git a/tensorflow/core/kernels/non_max_suppression_op.cc b/tensorflow/core/kernels/non_max_suppression_op.cc index 4d4851c70c..dc95f67ff0 100644 --- a/tensorflow/core/kernels/non_max_suppression_op.cc +++ b/tensorflow/core/kernels/non_max_suppression_op.cc @@ -33,6 +33,7 @@ limitations under the License. #include "tensorflow/core/platform/logging.h" namespace tensorflow { +namespace { typedef Eigen::ThreadPoolDevice CPUDevice; @@ -89,6 +90,63 @@ static inline float ComputeIOU(typename TTypes::ConstTensor boxes, return intersection_area / (area_i + area_j - intersection_area); } +void DoNonMaxSuppressionOp(OpKernelContext* context, + const Tensor& boxes, + const Tensor& scores, + const Tensor& max_output_size, + const float iou_threshold) { + OP_REQUIRES(context, iou_threshold >= 0 && iou_threshold <= 1, + errors::InvalidArgument("iou_threshold must be in [0, 1]")); + + int num_boxes = 0; + ParseAndCheckBoxSizes(context, boxes, scores, &num_boxes); + if (!context->status().ok()) { + return; + } + + const int output_size = + std::min(max_output_size.scalar()(), num_boxes); + typename TTypes::ConstTensor boxes_data = + boxes.tensor(); + + std::vector scores_data(num_boxes); + std::copy_n(scores.flat().data(), num_boxes, scores_data.begin()); + std::vector sorted_indices; + DecreasingArgSort(scores_data, &sorted_indices); + + std::vector active(num_boxes, true); + std::vector selected; + int num_active = active.size(); + for (int i = 0; i < num_boxes; ++i) { + if (num_active == 0 || selected.size() >= output_size) break; + if (active[i]) { + selected.push_back(sorted_indices[i]); + } else { + continue; + } + for (int j = i + 1; j < num_boxes; ++j) { + if (active[j]) { + float iou = + ComputeIOU(boxes_data, sorted_indices[i], sorted_indices[j]); + if (iou > iou_threshold) { + active[j] = false; + num_active--; + } + } + } + } + + // Allocate output tensor + Tensor* output = nullptr; + TensorShape output_shape({static_cast(selected.size())}); + OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output)); + typename TTypes::Tensor selected_indices_data = + output->tensor(); + std::copy_n(selected.begin(), selected.size(), selected_indices_data.data()); +} + +} // namespace + template class NonMaxSuppressionOp : public OpKernel { public: @@ -98,9 +156,6 @@ class NonMaxSuppressionOp : public OpKernel { } void Compute(OpKernelContext* context) override { - OP_REQUIRES(context, iou_threshold_ >= 0 && iou_threshold_ <= 1, - errors::InvalidArgument("iou_threshold must be in [0, 1]")); - // boxes: [num_boxes, 4] const Tensor& boxes = context->input(0); // scores: [num_boxes] @@ -112,59 +167,48 @@ class NonMaxSuppressionOp : public OpKernel { errors::InvalidArgument("max_output_size must be 0-D, got shape ", max_output_size.shape().DebugString())); - int num_boxes = 0; - ParseAndCheckBoxSizes(context, boxes, scores, &num_boxes); - if (!context->status().ok()) { - return; - } - - const int output_size = - std::min(max_output_size.scalar()(), num_boxes); - typename TTypes::ConstTensor boxes_data = - boxes.tensor(); - - std::vector scores_data(num_boxes); - std::copy_n(scores.flat().data(), num_boxes, scores_data.begin()); - std::vector sorted_indices; - DecreasingArgSort(scores_data, &sorted_indices); - - std::vector active(num_boxes, true); - std::vector selected; - int num_active = active.size(); - for (int i = 0; i < num_boxes; ++i) { - if (num_active == 0 || selected.size() >= output_size) break; - if (active[i]) { - selected.push_back(sorted_indices[i]); - } else { - continue; - } - for (int j = i + 1; j < num_boxes; ++j) { - if (active[j]) { - float iou = - ComputeIOU(boxes_data, sorted_indices[i], sorted_indices[j]); - if (iou > iou_threshold_) { - active[j] = false; - num_active--; - } - } - } - } - - // Allocate output tensor - Tensor* output = nullptr; - TensorShape output_shape({static_cast(selected.size())}); - OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output)); - typename TTypes::Tensor selected_indices_data = - output->tensor(); - std::copy_n(selected.begin(), selected.size(), - selected_indices_data.data()); + DoNonMaxSuppressionOp(context, boxes, scores, max_output_size, iou_threshold_); } private: float iou_threshold_; }; +template +class NonMaxSuppressionV2Op : public OpKernel { + public: + explicit NonMaxSuppressionV2Op(OpKernelConstruction* context) + : OpKernel(context) { + } + + void Compute(OpKernelContext* context) override { + // boxes: [num_boxes, 4] + const Tensor& boxes = context->input(0); + // scores: [num_boxes] + const Tensor& scores = context->input(1); + // max_output_size: scalar + const Tensor& max_output_size = context->input(2); + OP_REQUIRES( + context, TensorShapeUtils::IsScalar(max_output_size.shape()), + errors::InvalidArgument("max_output_size must be 0-D, got shape ", + max_output_size.shape().DebugString())); + // iou_threshold: scalar + const Tensor& iou_threshold = context->input(3); + OP_REQUIRES( + context, TensorShapeUtils::IsScalar(iou_threshold.shape()), + errors::InvalidArgument("iou_threshold must be 0-D, got shape ", + iou_threshold.shape().DebugString())); + + const float iou_threshold_val = iou_threshold.scalar()(); + + DoNonMaxSuppressionOp(context, boxes, scores, max_output_size, iou_threshold_val); + } +}; + REGISTER_KERNEL_BUILDER(Name("NonMaxSuppression").Device(DEVICE_CPU), NonMaxSuppressionOp); +REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV2").Device(DEVICE_CPU), + NonMaxSuppressionV2Op); + } // namespace tensorflow diff --git a/tensorflow/core/kernels/non_max_suppression_op_test.cc b/tensorflow/core/kernels/non_max_suppression_op_test.cc index 72e368db77..0a075b48b0 100644 --- a/tensorflow/core/kernels/non_max_suppression_op_test.cc +++ b/tensorflow/core/kernels/non_max_suppression_op_test.cc @@ -141,6 +141,161 @@ TEST_F(NonMaxSuppressionOpTest, TestInconsistentBoxAndScoreShapes) { AddInputFromArray(TensorShape({5}), {.9f, .75f, .6f, .95f, .5f}); AddInputFromArray(TensorShape({}), {30}); Status s = RunOpKernel(); + + ASSERT_FALSE(s.ok()); + EXPECT_TRUE( + StringPiece(s.ToString()).contains("scores has incompatible shape")) + << s; +} + +TEST_F(NonMaxSuppressionOpTest, TestInvalidIOUThreshold) { + MakeOp(1.2); + AddInputFromArray(TensorShape({1, 4}), {0, 0, 1, 1}); + AddInputFromArray(TensorShape({1}), {.9f}); + AddInputFromArray(TensorShape({}), {3}); + Status s = RunOpKernel(); + + ASSERT_FALSE(s.ok()); + EXPECT_TRUE( + StringPiece(s.ToString()).contains("iou_threshold must be in [0, 1]")) + << s; +} + +TEST_F(NonMaxSuppressionOpTest, TestEmptyInput) { + MakeOp(.5); + AddInputFromArray(TensorShape({0, 4}), {}); + AddInputFromArray(TensorShape({0}), {}); + AddInputFromArray(TensorShape({}), {30}); + TF_ASSERT_OK(RunOpKernel()); + + Tensor expected(allocator(), DT_INT32, TensorShape({0})); + test::FillValues(&expected, {}); + test::ExpectTensorEqual(expected, *GetOutput(0)); +} + +// +// NonMaxSuppressionV2Op Tests +// + +class NonMaxSuppressionV2OpTest : public OpsTestBase { + protected: + void MakeOp() { + TF_EXPECT_OK(NodeDefBuilder("non_max_suppression_op", "NonMaxSuppressionV2") + .Input(FakeInput(DT_FLOAT)) + .Input(FakeInput(DT_FLOAT)) + .Input(FakeInput(DT_INT32)) + .Input(FakeInput(DT_FLOAT)) + .Finalize(node_def())); + TF_EXPECT_OK(InitOp()); + } +}; + +TEST_F(NonMaxSuppressionV2OpTest, TestSelectFromThreeClusters) { + MakeOp(); + AddInputFromArray(TensorShape({6, 4}), + {0, 0, 1, 1, 0, 0.1f, 1, 1.1f, 0, -0.1f, 1, 0.9f, + 0, 10, 1, 11, 0, 10.1f, 1, 11.1f, 0, 100, 1, 101}); + AddInputFromArray(TensorShape({6}), {.9f, .75f, .6f, .95f, .5f, .3f}); + AddInputFromArray(TensorShape({}), {3}); + AddInputFromArray(TensorShape({}), {.5f}); + TF_ASSERT_OK(RunOpKernel()); + + Tensor expected(allocator(), DT_INT32, TensorShape({3})); + test::FillValues(&expected, {3, 0, 5}); + test::ExpectTensorEqual(expected, *GetOutput(0)); +} + +TEST_F(NonMaxSuppressionV2OpTest, TestSelectFromThreeClustersFlippedCoordinates) { + MakeOp(); + AddInputFromArray(TensorShape({6, 4}), + {1, 1, 0, 0, 0, 0.1f, 1, 1.1f, 0, .9f, 1, -0.1f, + 0, 10, 1, 11, 1, 10.1f, 0, 11.1f, 1, 101, 0, 100}); + AddInputFromArray(TensorShape({6}), {.9f, .75f, .6f, .95f, .5f, .3f}); + AddInputFromArray(TensorShape({}), {3}); + AddInputFromArray(TensorShape({}), {.5f}); + TF_ASSERT_OK(RunOpKernel()); + + Tensor expected(allocator(), DT_INT32, TensorShape({3})); + test::FillValues(&expected, {3, 0, 5}); + test::ExpectTensorEqual(expected, *GetOutput(0)); +} + +TEST_F(NonMaxSuppressionV2OpTest, TestSelectAtMostTwoBoxesFromThreeClusters) { + MakeOp(); + AddInputFromArray(TensorShape({6, 4}), + {0, 0, 1, 1, 0, 0.1f, 1, 1.1f, 0, -0.1f, 1, 0.9f, + 0, 10, 1, 11, 0, 10.1f, 1, 11.1f, 0, 100, 1, 101}); + AddInputFromArray(TensorShape({6}), {.9f, .75f, .6f, .95f, .5f, .3f}); + AddInputFromArray(TensorShape({}), {2}); + AddInputFromArray(TensorShape({}), {.5f}); + TF_ASSERT_OK(RunOpKernel()); + + Tensor expected(allocator(), DT_INT32, TensorShape({2})); + test::FillValues(&expected, {3, 0}); + test::ExpectTensorEqual(expected, *GetOutput(0)); +} + +TEST_F(NonMaxSuppressionV2OpTest, TestSelectAtMostThirtyBoxesFromThreeClusters) { + MakeOp(); + AddInputFromArray(TensorShape({6, 4}), + {0, 0, 1, 1, 0, 0.1f, 1, 1.1f, 0, -0.1f, 1, 0.9f, + 0, 10, 1, 11, 0, 10.1f, 1, 11.1f, 0, 100, 1, 101}); + AddInputFromArray(TensorShape({6}), {.9f, .75f, .6f, .95f, .5f, .3f}); + AddInputFromArray(TensorShape({}), {30}); + AddInputFromArray(TensorShape({}), {.5f}); + TF_ASSERT_OK(RunOpKernel()); + + Tensor expected(allocator(), DT_INT32, TensorShape({3})); + test::FillValues(&expected, {3, 0, 5}); + test::ExpectTensorEqual(expected, *GetOutput(0)); +} + +TEST_F(NonMaxSuppressionV2OpTest, TestSelectSingleBox) { + MakeOp(); + AddInputFromArray(TensorShape({1, 4}), {0, 0, 1, 1}); + AddInputFromArray(TensorShape({1}), {.9f}); + AddInputFromArray(TensorShape({}), {3}); + AddInputFromArray(TensorShape({}), {.5f}); + TF_ASSERT_OK(RunOpKernel()); + + Tensor expected(allocator(), DT_INT32, TensorShape({1})); + test::FillValues(&expected, {0}); + test::ExpectTensorEqual(expected, *GetOutput(0)); +} + +TEST_F(NonMaxSuppressionV2OpTest, TestSelectFromTenIdenticalBoxes) { + MakeOp(); + + int num_boxes = 10; + std::vector corners(num_boxes * 4); + std::vector scores(num_boxes); + for (int i = 0; i < num_boxes; ++i) { + corners[i * 4 + 0] = 0; + corners[i * 4 + 1] = 0; + corners[i * 4 + 2] = 1; + corners[i * 4 + 3] = 1; + scores[i] = .9; + } + AddInputFromArray(TensorShape({num_boxes, 4}), corners); + AddInputFromArray(TensorShape({num_boxes}), scores); + AddInputFromArray(TensorShape({}), {3}); + AddInputFromArray(TensorShape({}), {.5f}); + TF_ASSERT_OK(RunOpKernel()); + + Tensor expected(allocator(), DT_INT32, TensorShape({1})); + test::FillValues(&expected, {0}); + test::ExpectTensorEqual(expected, *GetOutput(0)); +} + +TEST_F(NonMaxSuppressionV2OpTest, TestInconsistentBoxAndScoreShapes) { + MakeOp(); + AddInputFromArray(TensorShape({6, 4}), + {0, 0, 1, 1, 0, 0.1f, 1, 1.1f, 0, -0.1f, 1, 0.9f, + 0, 10, 1, 11, 0, 10.1f, 1, 11.1f, 0, 100, 1, 101}); + AddInputFromArray(TensorShape({5}), {.9f, .75f, .6f, .95f, .5f}); + AddInputFromArray(TensorShape({}), {30}); + AddInputFromArray(TensorShape({}), {.5f}); + Status s = RunOpKernel(); ASSERT_FALSE(s.ok()); EXPECT_TRUE( @@ -148,11 +303,12 @@ TEST_F(NonMaxSuppressionOpTest, TestInconsistentBoxAndScoreShapes) { << s; } -TEST_F(NonMaxSuppressionOpTest, TestInvalidIOUThreshold) { - MakeOp(1.2); +TEST_F(NonMaxSuppressionV2OpTest, TestInvalidIOUThreshold) { + MakeOp(); AddInputFromArray(TensorShape({1, 4}), {0, 0, 1, 1}); AddInputFromArray(TensorShape({1}), {.9f}); AddInputFromArray(TensorShape({}), {3}); + AddInputFromArray(TensorShape({}), {1.2f}); Status s = RunOpKernel(); ASSERT_FALSE(s.ok()); @@ -161,11 +317,12 @@ TEST_F(NonMaxSuppressionOpTest, TestInvalidIOUThreshold) { << s; } -TEST_F(NonMaxSuppressionOpTest, TestEmptyInput) { - MakeOp(.5); +TEST_F(NonMaxSuppressionV2OpTest, TestEmptyInput) { + MakeOp(); AddInputFromArray(TensorShape({0, 4}), {}); AddInputFromArray(TensorShape({0}), {}); AddInputFromArray(TensorShape({}), {30}); + AddInputFromArray(TensorShape({}), {.5f}); TF_ASSERT_OK(RunOpKernel()); Tensor expected(allocator(), DT_INT32, TensorShape({0})); diff --git a/tensorflow/core/ops/image_ops.cc b/tensorflow/core/ops/image_ops.cc index cfcf2d1ff9..bbfdb34758 100644 --- a/tensorflow/core/ops/image_ops.cc +++ b/tensorflow/core/ops/image_ops.cc @@ -972,11 +972,50 @@ method: A string specifying the interpolation method. Only 'bilinear' is // -------------------------------------------------------------------------- REGISTER_OP("NonMaxSuppression") + .Input("boxes: float") + .Input("scores: float") + .Input("max_output_size: int32") + .Output("selected_indices: int32") + .Attr("iou_threshold: float = 0.5") + .SetShapeFn([](InferenceContext* c) { + c->set_output(0, c->Vector(c->UnknownDim())); + return Status::OK(); + }) + .Doc(R"doc( +Greedily selects a subset of bounding boxes in descending order of score, +pruning away boxes that have high intersection-over-union (IOU) overlap +with previously selected boxes. Bounding boxes are supplied as +[y1, x1, y2, x2], where (y1, x1) and (y2, x2) are the coordinates of any +diagonal pair of box corners and the coordinates can be provided as normalized +(i.e., lying in the interval [0, 1]) or absolute. Note that this algorithm +is agnostic to where the origin is in the coordinate system. Note that this +algorithm is invariant to orthogonal transformations and translations +of the coordinate system; thus translating or reflections of the coordinate +system result in the same boxes being selected by the algorithm. +The output of this operation is a set of integers indexing into the input +collection of bounding boxes representing the selected boxes. The bounding +box coordinates corresponding to the selected indices can then be obtained +using the `tf.gather operation`. For example: + selected_indices = tf.image.non_max_suppression( + boxes, scores, max_output_size, iou_threshold) + selected_boxes = tf.gather(boxes, selected_indices) +boxes: A 2-D float tensor of shape `[num_boxes, 4]`. +scores: A 1-D float tensor of shape `[num_boxes]` representing a single + score corresponding to each box (each row of boxes). +max_output_size: A scalar integer tensor representing the maximum number of + boxes to be selected by non max suppression. +iou_threshold: A float representing the threshold for deciding whether boxes + overlap too much with respect to IOU. +selected_indices: A 1-D integer tensor of shape `[M]` representing the selected + indices from the boxes tensor, where `M <= max_output_size`. +)doc"); + +REGISTER_OP("NonMaxSuppressionV2") .Input("boxes: float") .Input("scores: float") .Input("max_output_size: int32") + .Input("iou_threshold: float") .Output("selected_indices: int32") - .Attr("iou_threshold: float = 0.5") .SetShapeFn([](InferenceContext* c) { c->set_output(0, c->Vector(c->UnknownDim())); return Status::OK(); @@ -998,7 +1037,7 @@ collection of bounding boxes representing the selected boxes. The bounding box coordinates corresponding to the selected indices can then be obtained using the `tf.gather operation`. For example: - selected_indices = tf.image.non_max_suppression( + selected_indices = tf.image.non_max_suppression_v2( boxes, scores, max_output_size, iou_threshold) selected_boxes = tf.gather(boxes, selected_indices) @@ -1007,8 +1046,8 @@ scores: A 1-D float tensor of shape `[num_boxes]` representing a single score corresponding to each box (each row of boxes). max_output_size: A scalar integer tensor representing the maximum number of boxes to be selected by non max suppression. -iou_threshold: A float representing the threshold for deciding whether boxes - overlap too much with respect to IOU. +iou_threshold: A 0-D float tensor representing the threshold for deciding whether + boxes overlap too much with respect to IOU. selected_indices: A 1-D integer tensor of shape `[M]` representing the selected indices from the boxes tensor, where `M <= max_output_size`. )doc"); diff --git a/tensorflow/python/ops/image_ops.py b/tensorflow/python/ops/image_ops.py index c29ae26f04..2aad1e1519 100644 --- a/tensorflow/python/ops/image_ops.py +++ b/tensorflow/python/ops/image_ops.py @@ -59,6 +59,7 @@ See the @{$python/image} guide. @@per_image_standardization @@draw_bounding_boxes @@non_max_suppression +@@non_max_suppression_v2 @@sample_distorted_bounding_box @@total_variation """ diff --git a/tensorflow/python/ops/image_ops_impl.py b/tensorflow/python/ops/image_ops_impl.py index 3e140ce047..ae7999a71e 100644 --- a/tensorflow/python/ops/image_ops_impl.py +++ b/tensorflow/python/ops/image_ops_impl.py @@ -52,6 +52,7 @@ ops.NotDifferentiable('SampleDistortedBoundingBox') # latent bugs here. ops.NotDifferentiable('ExtractGlimpse') ops.NotDifferentiable('NonMaxSuppression') +ops.NotDifferentiable('NonMaxSuppressionV2') def _assert(cond, ex_type, msg): -- GitLab